From 5538d7c2a202c657a68fe7c9629754c7efef4720 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 24 Mar 2025 11:14:58 +0000 Subject: [PATCH 01/14] Templated LHS QAI8 IGEMM packing kernel Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + .../pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c | 344 ++++++++++++++++++ .../pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h | 56 +++ 4 files changed, 402 insertions(+) create mode 100644 kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 5aecb07e..176ef402 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,6 +181,7 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME + kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 3796d88d..ceb8fa16 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -114,6 +114,7 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ + "pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", diff --git a/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c b/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c new file mode 100644 index 00000000..a138a825 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c @@ -0,0 +1,344 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_kr = 4; + +// Max size is calculated from maximum vector length +// of 256 bytes, divided by sizeof(int8_t) * kai_kr +#define BLOCK_HEIGHT_MAX 64 + +static size_t kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme(void) { + return kai_mr * kai_get_sme_vector_length_u8() / kai_kr; +} + +size_t kai_get_m_step_lhs_igemm_pack_x8p2vlx4_x8_sme(void) { + return kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme(); +} + +size_t kai_get_lhs_packed_offset_lhs_igemm_pack_x8p2vlx4_x8_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_lhs_igemm_pack_x8p2vlx4_x8_sme() == 0); + + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t); +} + +size_t kai_get_lhs_packed_size_lhs_igemm_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length) { + const size_t m_end = kai_roundup(m, kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme()); + return kai_get_lhs_packed_offset_lhs_igemm_pack_x8p2vlx4_x8_sme(m_end, k_chunk_count, k_chunk_length); +} + +void kai_run_lhs_igemm_pack_x8p2vlx4_x8_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* zero, void* lhs_packed) { + KAI_ASSUME(lhs_ptrs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + const size_t m_step = kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme(); + const size_t row_offset = 0; + const size_t width = k_chunk_length; + + KAI_ASSERT(m_step <= BLOCK_HEIGHT_MAX); + const uint8_t* in[BLOCK_HEIGHT_MAX]; + + uint8_t* out_base = lhs_packed; + for (size_t i_m = 0; i_m < m; i_m += m_step) { + for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; i_k_chunk += 1) { + const size_t height = KAI_MIN(m - i_m, m_step); + void* out = out_base; + for (size_t y = 0; y < height; y += 1) { + KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); + in[y] = *(lhs_ptrs + i_k_chunk + (i_m + y) * k_chunk_count); + if (in[y] != zero) { + in[y] += lhs_ptr_offset; + } + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x23, %x[width]\n" + "mov x21, %x[width]\n" + "cntb x20\n" + "incb x23\n" + "sub x7, x20, #0x1\n" + "cntw x8\n" + "sub x23, x23, #0x1\n" + "ands x7, x21, x7\n" + "udiv x23, x23, x20\n" // n_passes = ceildiv(width, VL) + "csel x7, x7, x20, NE\n" + "lsl x22, %x[height], #0x1\n" // height * 2 + "lsl x21, x8, #0x1\n" + "sub x20, x23, #0x1\n" + "add x7, x7, #0x3\n" + "sub x17, x8, #0x2\n" + "whilelt p9.b, XZR, x22\n" + "whilelt p8.b, x21, x22\n" + "mov x16, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "cntw x28, ALL, MUL #3\n" + "ldr x27, [x11, #0x0]\n" + "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 + "and x26, x23, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "ldr x25, [x10, #0x0]\n" + "lsr x7, x7, #0x2\n" + "ptrue p11.s\n" + "ldr x24, [x11, #0x8]\n" + "zip1 p10.b, p9.b, p8.b\n" + "mov x23, %x[row_offset]\n" + "ldr x21, [x10, #0x8]\n" + "mov x22, %x[out]\n" + "whilelt p9.b, x16, %x[width]\n" + "whilelt p8.b, x16, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x17, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" + ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" + ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" + ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" + ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" + "add x12, x12, #0x8\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x17, LSL #2\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" + ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" + ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" + ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" + "ldr x27, [x11, #0x0]\n" + "incb x16\n" + ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "incb x23\n" + "cbz x20, 8f\n" + "mov x20, x20\n" + "3:" // K loop: Main loop + "whilelt p8.b, x16, %x[width]\n" + "mov x15, #0x0\n" + "mov x14, #0x0\n" + "cbz x17, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" + ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" + ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" + ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" + ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" + ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" + ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" + ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" + "add x15, x15, #0x8\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x14, x14, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x14, x17\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" + ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" + ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" + ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" + ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" + "ldr x27, [x11, #0x0]\n" + "mov x13, #0x0\n" + ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" + ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" + "ldr x25, [x10, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" + ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" + "whilelt p9.b, x16, %x[width]\n" + ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" + "incb x16\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "incb x23\n" + "whilelt p8.b, x16, %x[width]\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "cbz x17, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" + ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" + ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" + ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" + ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" + ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" + ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" + ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + "add x13, x13, #0x8\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x17\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" + ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" + ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" + ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" + ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" + ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" + ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" + "whilelt p9.b, x16, %x[width]\n" + ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + "subs x20, x20, #0x1\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "incb x16\n" + "incb x23\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x26, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.b, x16, %x[width]\n" + "mov x13, #0x0\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25306d23 // psel p3.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d22 // psel p2.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25356141 // psel p1.b, p8.b/Z, p10.b[w13, #2]\n" + ".inst 0x253d6140 // psel p0.b, p8.b/Z, p10.b[w13, #3]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "cmp x12, x8\n" + "ldr x20, [x11, x8, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe01726a2 // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x23]\n" + ".inst 0xe0172283 // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x23]\n" + "add x13, x13, #0x4\n" + "blt 9b\n" + "whilelt p9.b, x16, %x[width]\n" + "whilelt p8.b, x16, %x[width]\n" + "mov x20, #0x0\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" + "add x20, x20, #0x4\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 10b\n" + "whilelt p8.b, x16, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", + "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", + "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", + "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + out_base += m_step * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t); + } + } +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h b/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h new file mode 100644 index 00000000..3b85357b --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h @@ -0,0 +1,56 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return Step size for row index +size_t kai_get_m_step_lhs_igemm_pack_x8p2vlx4_x8_sme(void); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_igemm_pack_x8p2vlx4_x8_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_igemm_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); + +/// Pack the LHS matrix for use with indirect matrix multiplication +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of `m * k_chunk_count` pointers. +/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs array, excluding zero pointers. +/// @param[in] zero Pointer to a zero element. Used to check for padding pointers in @ref lhs_ptrs. +/// @param[out] lhs_packed Packed LHS matrix. +void kai_run_lhs_igemm_pack_x8p2vlx4_x8_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* zero, void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus -- GitLab From f9a78032e30153f25a286badaab09e311719be2a Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Fri, 28 Mar 2025 16:38:04 +0000 Subject: [PATCH 02/14] Add Templated RHS Packing Kernel for IGEMM QSI8 Signed-off-by: Mohammed Suhail Munshi --- CMakeLists.txt | 1 + kai/kai_common.h | 3 + kai/ukernels/matmul/BUILD.bazel | 1 + ...ack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c | 276 ++++++++++++++++++ ...ack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h | 90 ++++++ 5 files changed, 371 insertions(+) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 176ef402..f4962088 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c diff --git a/kai/kai_common.h b/kai/kai_common.h index c1cb1eca..47e00bfa 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -54,6 +54,9 @@ extern "C" { #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) +/// Largest supported SME vector length in bytes +#define KAI_SME_VEC_LENGTH_MAX_BYTES 256 // NOLINT(cppcoreguidelines-macro-to-enum,modernize-macro-to-enum) + /// Gets the version of the project in the Major.Minor.Patch semantic versioning format. /// /// @return Project version as a string literal. diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index ceb8fa16..4d71c2f5 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -119,6 +119,7 @@ SME_KERNELS = [ "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", + "pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", diff --git a/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c new file mode 100644 index 00000000..bba57725 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -0,0 +1,276 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. +#include "kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_nr = 2; +static const size_t kai_kr = 4; +static const size_t kai_num_bytes_input = sizeof(uint8_t); +static const size_t kai_num_bytes_output = sizeof(uint8_t); +static const size_t kai_num_bytes_bias = sizeof(int32_t); +static const size_t kai_num_bytes_scale = sizeof(float32_t); + +#define MAX_N_STEP (KAI_SME_VEC_LENGTH_MAX_BYTES / (kai_kr * kai_nr)) + +size_t kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { + return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; +} + +size_t kai_get_rhs_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_scale_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_scale; +} + +static size_t kai_get_rhs_packed_stride_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t k_chunk_count, size_t k_chunk_length) { + return kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() * + (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * kai_num_bytes_output + + kai_num_bytes_scale); +} + +size_t kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); + + const size_t block_idx = n_idx / kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(); + return block_idx * + kai_get_rhs_packed_stride_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); +} + +size_t kai_get_rhs_packed_size_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length) { + const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); + return kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + n_rounded_up, k_chunk_count, k_chunk_length); +} + +void kai_run_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params) { + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(params != NULL); + + size_t height = k_chunk_length; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_row_stride; + + KAI_ASSERT(kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); + uint8_t pad_row[MAX_N_STEP]; + if (height % kai_kr) { + memset(pad_row, 0, MAX_N_STEP * sizeof(uint8_t)); + } + + size_t out_stride = + kai_get_rhs_packed_stride_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); + const int32_t input_zero_point = params->lhs_zero_point; + const float scale_multiplier = params->scale_multiplier; + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x12, %x[out]\n" + "mov x11, %x[k_chunk_count]\n" + "ptrue p2.b\n" + "incb %x[out], ALL, MUL #2\n" + "1:" // Chunk Loop + "mov x10, %x[height]\n" + "cmp x10, #0x8\n" + "blt 5f\n" + "2:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[out]\n" + "add x27, x9, %x[in_stride]\n" + "sub x10, x10, #0x8\n" + "add x26, x27, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x25, x26, %x[in_stride]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "3:" // Main row loop: Column loop + "whilelt p0.b, XZR, x24\n" + "decw x24, ALL, MUL #2\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "cmp x24, #0x0\n" + "incd x9, ALL, MUL #4\n" + "ld1b { z22.b }, p0/Z, [x27]\n" + "incd x27, ALL, MUL #4\n" + "ld1b { z17.b }, p0/Z, [x26]\n" + "incd x26, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x25]\n" + "incd x25, ALL, MUL #4\n" + "ld1b { z20.b }, p0/Z, [x23]\n" + "incd x23, ALL, MUL #4\n" + "ld1b { z19.b }, p0/Z, [x22]\n" + "zip1 z21.b, z18.b, z17.b\n" + "incd x22, ALL, MUL #4\n" + "ld1b { z18.b }, p0/Z, [x21]\n" + "zip1 z17.b, z22.b, z16.b\n" + "incd x21, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x20]\n" + "incd x20, ALL, MUL #4\n" + "zip1 z20.b, z20.b, z18.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "zip1 z19.b, z21.b, z17.b\n" + "zip2 z18.b, z21.b, z17.b\n" + "zip1 z17.b, z20.b, z16.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "st1b { z19.b }, p2, [x28]\n" + "st1b { z18.b }, p2, [x28, #1, MUL VL]\n" + "st1b { z17.b }, p2, [x28, #2, MUL VL]\n" + "st1b { z16.b }, p2, [x28, #3, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 3b\n" + "cmp x10, #0x8\n" + "addvl %x[out], %x[out], #4\n" + "bge 2b\n" + "cbz x10, 9f\n" + "5:" // Main loop skip + "6:" // Tail row loop: Head + "mov x9, %x[in]\n" + "cmp x10, #0x3\n" + "add x27, x9, %x[in_stride]\n" + "cntw x24, ALL, MUL #2\n" + "add x26, x27, %x[in_stride]\n" + "csel x23, x24, XZR, GT\n" + "add x25, x26, %x[in_stride]\n" + "csel x22, x24, XZR, GE\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x28, %x[out]\n" + "csel %x[in], %x[in], x25, GT\n" + "csel x25, x25, %x[pad_row], GT\n" + "csel %x[in], %x[in], x26, GE\n" + "csel x26, x26, %x[pad_row], GE\n" + "cmp x10, #0x1\n" + "sub x10, x10, #0x4\n" + "csel %x[in], %x[in], x27, GT\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x21, x24, XZR, GT\n" + "mov x20, %x[width]\n" + "7:" // Tail row loop: Column loop + "whilelt p0.b, XZR, x20\n" + "decw x20, ALL, MUL #2\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "cmp x20, #0x0\n" + "add x9, x9, x24\n" + "ld1b { z19.b }, p0/Z, [x27]\n" + "add x27, x27, x21\n" + "ld1b { z17.b }, p0/Z, [x26]\n" + "add x26, x26, x22\n" + "ld1b { z16.b }, p0/Z, [x25]\n" + "add x25, x25, x23\n" + "zip1 z18.b, z18.b, z17.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "zip1 z17.b, z18.b, z16.b\n" + "zip2 z16.b, z18.b, z16.b\n" + "st1b { z17.b }, p2, [x28]\n" + "st1b { z16.b }, p2, [x28, #1, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 7b\n" + "cmp x10, #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 6b\n" + "9:" // Done + "sub x11, x11, #0x1\n" + "cbnz x11, 1b\n" + "mov x22, %x[out]\n" + "mov x21, %x[width]\n" + "dup z18.s, %w[scale_multiplier]\n" + "cbz %x[scale], 11f\n" + "10:" // Scale: Full loop + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [%x[scale]]\n" + "cmp x21, #0x0\n" + "ld1w { z16.s }, p0/Z, [%x[scale], #1, MUL VL]\n" + "incb %x[scale], ALL, MUL #2\n" + "fmul z17.s, z17.s, z18.s\n" + "fmul z16.s, z16.s, z18.s\n" + "st1w { z17.s }, p2, [x22]\n" + "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Scale: Done + "cbz %x[width], 14f\n" + "cbz %x[height], 14f\n" + "dup z21.s, %w[input_zero_point]\n" + "add x25, %x[height], #0x3\n" + "cntw x24, ALL, MUL #2\n" + "mov z20.b, #0x1\n" + "lsr x25, x25, #0x2\n" + "mov x23, %x[width]\n" + "mul x25, %x[k_chunk_count], x25\n" + "addvl x22, x12, #2\n" + "neg z21.s, p2/M, z21.s\n" + "12:" // Bias: N loop + "mov x21, x22\n" + "mov x20, x25\n" + "mov z19.s, #0x0\n" + "mov z18.s, #0x0\n" + "13:" // Bias: K loop + "ld1b { z17.b }, p2/Z, [x21]\n" + "subs x20, x20, #0x1\n" + "ld1b { z16.b }, p2/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "sdot z19.s, z17.b, z20.b\n" + "sdot z18.s, z16.b, z20.b\n" + "bgt 13b\n" + "mov x20, x23\n" + "add x22, x22, %x[out_stride]\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [%x[bias]]\n" + "subs x23, x23, x24\n" + "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" + "addvl %x[bias], %x[bias], #2\n" + "mla z17.s, p2/M, z19.s, z21.s\n" + "mla z16.s, p2/M, z18.s, z21.s\n" + "st1w { z17.s }, p2, [x12]\n" + "st1w { z16.s }, p2, [x12, #1, MUL VL]\n" + "add x12, x12, %x[out_stride]\n" + "bgt 12b\n" + "14:" // Bias: Done + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out), [scale] "+&r"(scale) + : [height] "r"(height), [in_stride] "r"(in_stride), [input_zero_point] "r"(input_zero_point), + [k_chunk_count] "r"(k_chunk_count), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), + [scale_multiplier] "r"(scale_multiplier), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", + "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h new file mode 100644 index 00000000..76a69241 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -0,0 +1,90 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return Step size for column index. +size_t kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the scale buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_scale_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of columns. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length); + +/// Runs the RHS packing function for matrix multiplication. +/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Scale: @ref kai_get_scale_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] params Extra packing parameters. +void kai_run_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus -- GitLab From 174f52cfbc8bc28eac59bdddd824c82f3eab67a1 Mon Sep 17 00:00:00 2001 From: Felix Johnny Thomasmathibalan Date: Fri, 28 Mar 2025 16:55:59 +0000 Subject: [PATCH 03/14] Rename RHS igemm to imatmul Signed-off-by: Felix Thomasmathibalan --- CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ...ck_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c} | 36 +++++++++---------- ...ck_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h} | 22 ++++++------ 4 files changed, 31 insertions(+), 31 deletions(-) rename kai/ukernels/matmul/pack/{kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c => kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c} (85%) rename kai/ukernels/matmul/pack/{kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h => kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h} (75%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f4962088..7355f2d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,7 +185,7 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c - kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 4d71c2f5..68e503fa 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -119,7 +119,7 @@ SME_KERNELS = [ "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", - "pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", + "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", diff --git a/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c similarity index 85% rename from kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c rename to kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c index bba57725..b8bc2cbb 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -10,7 +10,7 @@ #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. -#include "kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" +#include "kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include #include @@ -27,48 +27,48 @@ static const size_t kai_num_bytes_scale = sizeof(float32_t); #define MAX_N_STEP (KAI_SME_VEC_LENGTH_MAX_BYTES / (kai_kr * kai_nr)) -size_t kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { +size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; } -size_t kai_get_rhs_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { - KAI_ASSUME(n_idx % kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); return n_idx * kai_num_bytes_input; } -size_t kai_get_bias_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { return n_idx * kai_num_bytes_bias; } -size_t kai_get_scale_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { +size_t kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { return n_idx * kai_num_bytes_scale; } -static size_t kai_get_rhs_packed_stride_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( +static size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t k_chunk_count, size_t k_chunk_length) { - return kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() * + return kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() * (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * kai_num_bytes_output + kai_num_bytes_scale); } -size_t kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { - KAI_ASSUME(n_idx % kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); - const size_t block_idx = n_idx / kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(); + const size_t block_idx = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(); return block_idx * - kai_get_rhs_packed_stride_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); } -size_t kai_get_rhs_packed_size_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length) { - const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); - return kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); + return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( n_rounded_up, k_chunk_count, k_chunk_length); } -void kai_run_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( +void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params) { KAI_ASSUME(rhs != NULL); @@ -82,14 +82,14 @@ void kai_run_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( void* out = rhs_packed; const size_t in_stride = rhs_row_stride; - KAI_ASSERT(kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); + KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); uint8_t pad_row[MAX_N_STEP]; if (height % kai_kr) { memset(pad_row, 0, MAX_N_STEP * sizeof(uint8_t)); } size_t out_stride = - kai_get_rhs_packed_stride_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); const int32_t input_zero_point = params->lhs_zero_point; const float scale_multiplier = params->scale_multiplier; __asm__ __volatile__( diff --git a/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h similarity index 75% rename from kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h rename to kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h index 76a69241..77b1dbde 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -19,28 +19,28 @@ extern "C" { /// The starting column index must be divisible by `n_step`. /// /// @return Step size for column index. -size_t kai_get_n_step_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); +size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// /// @param[in] n_idx Column index. Must be divisible by `n_step` /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); /// Gets the offset in bytes to the data element in the bias buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_bias_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); /// Gets the offset in bytes to the data element in the scale buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_scale_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); +size_t kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); /// Gets the offset in bytes to the data element in the packed RHS buffer. /// @@ -49,7 +49,7 @@ size_t kai_get_scale_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( /// @param[in] k_chunk_length Number of rows in each chunk. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); /// Gets the size in bytes of the packed RHS buffer. @@ -59,7 +59,7 @@ size_t kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32 /// @param[in] k_chunk_length Number of rows in each chunk. /// /// @return The size in bytes of the packed RHS buffer. -size_t kai_get_rhs_packed_size_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length); /// Runs the RHS packing function for matrix multiplication. @@ -67,10 +67,10 @@ size_t kai_get_rhs_packed_size_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_s /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset /// calculated using the following functions: /// -/// * RHS: @ref kai_get_rhs_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. -/// * Bias: @ref kai_get_bias_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. -/// * Scale: @ref kai_get_scale_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. -/// * Output: @ref kai_get_rhs_packed_offset_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * RHS: @ref kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Scale: @ref kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. /// /// @param[in] n Number of columns of the output matrix. /// @param[in] k_chunk_count Number of chunks. @@ -81,7 +81,7 @@ size_t kai_get_rhs_packed_size_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_s /// @param[in] scale Scale data buffer. /// @param[out] rhs_packed Packed RHS matrix. /// @param[in] params Extra packing parameters. -void kai_run_rhs_igemm_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( +void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params); -- GitLab From 57f6a920108cc5962a8cf48b7df7aa5aaf2f1b45 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 31 Mar 2025 11:15:18 +0000 Subject: [PATCH 04/14] Align LHS wit RHS, and adjust pack order Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ...=> kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c} | 43 +++++++++---------- ...=> kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h} | 8 ++-- 4 files changed, 26 insertions(+), 29 deletions(-) rename kai/ukernels/matmul/pack/{kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c => kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c} (92%) rename kai/ukernels/matmul/pack/{kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h => kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h} (86%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7355f2d7..620e1671 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,7 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME - kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 68e503fa..c9485567 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -114,7 +114,7 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ - "pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme", + "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", diff --git a/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c similarity index 92% rename from kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c rename to kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c index a138a825..ca794e2b 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c @@ -11,52 +11,49 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. -#include "kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h" +#include "kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h" #include #include #include "kai/kai_common.h" -static const size_t kai_mr = 2; -static const size_t kai_kr = 4; +#define MR 2 +#define KR 4 +#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR) -// Max size is calculated from maximum vector length -// of 256 bytes, divided by sizeof(int8_t) * kai_kr -#define BLOCK_HEIGHT_MAX 64 - -static size_t kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme(void) { - return kai_mr * kai_get_sme_vector_length_u8() / kai_kr; +static size_t kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme(void) { + return MR * kai_get_sme_vector_length_u8() / KR; } -size_t kai_get_m_step_lhs_igemm_pack_x8p2vlx4_x8_sme(void) { - return kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme(); +size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8_sme(void) { + return kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme(); } -size_t kai_get_lhs_packed_offset_lhs_igemm_pack_x8p2vlx4_x8_sme( +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8_sme( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { - KAI_ASSUME(m_idx % kai_get_m_step_lhs_igemm_pack_x8p2vlx4_x8_sme() == 0); + KAI_ASSUME(m_idx % kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8_sme() == 0); - return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t); + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, KR) * sizeof(int8_t); } -size_t kai_get_lhs_packed_size_lhs_igemm_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length) { - const size_t m_end = kai_roundup(m, kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme()); - return kai_get_lhs_packed_offset_lhs_igemm_pack_x8p2vlx4_x8_sme(m_end, k_chunk_count, k_chunk_length); +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length) { + const size_t m_end = kai_roundup(m, kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme()); + return kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8_sme(m_end, k_chunk_count, k_chunk_length); } -void kai_run_lhs_igemm_pack_x8p2vlx4_x8_sme( +void kai_run_lhs_imatmul_pack_x8p2vlx4_x8_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, const void* zero, void* lhs_packed) { KAI_ASSUME(lhs_ptrs != NULL); KAI_ASSUME(lhs_packed != NULL); - const size_t m_step = kai_get_mr_lhs_igemm_pack_x8p2vlx4_x8_sme(); + const size_t m_step = kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme(); const size_t row_offset = 0; const size_t width = k_chunk_length; - KAI_ASSERT(m_step <= BLOCK_HEIGHT_MAX); - const uint8_t* in[BLOCK_HEIGHT_MAX]; + KAI_ASSERT(m_step <= MAX_M_STEP); + const uint8_t* in[MAX_M_STEP]; uint8_t* out_base = lhs_packed; for (size_t i_m = 0; i_m < m; i_m += m_step) { @@ -65,7 +62,7 @@ void kai_run_lhs_igemm_pack_x8p2vlx4_x8_sme( void* out = out_base; for (size_t y = 0; y < height; y += 1) { KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); - in[y] = *(lhs_ptrs + i_k_chunk + (i_m + y) * k_chunk_count); + in[y] = *(lhs_ptrs + i_m * k_chunk_count + i_k_chunk * m_step + y); if (in[y] != zero) { in[y] += lhs_ptr_offset; } @@ -336,7 +333,7 @@ void kai_run_lhs_igemm_pack_x8p2vlx4_x8_sme( "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); - out_base += m_step * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t); + out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(int8_t); } } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h similarity index 86% rename from kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h rename to kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h index 3b85357b..762f21d6 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_igemm_pack_x8p2vlx4_x8_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h @@ -17,7 +17,7 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return Step size for row index -size_t kai_get_m_step_lhs_igemm_pack_x8p2vlx4_x8_sme(void); +size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8_sme(void); /// Gets the offset in bytes to the data element in the packed LHS buffer. /// @@ -26,7 +26,7 @@ size_t kai_get_m_step_lhs_igemm_pack_x8p2vlx4_x8_sme(void); /// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_lhs_igemm_pack_x8p2vlx4_x8_sme( +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8_sme( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); /// Gets the size in bytes of the packed LHS buffer. @@ -36,7 +36,7 @@ size_t kai_get_lhs_packed_offset_lhs_igemm_pack_x8p2vlx4_x8_sme( /// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The size in bytes of the packed LHS buffer. -size_t kai_get_lhs_packed_size_lhs_igemm_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); /// Pack the LHS matrix for use with indirect matrix multiplication /// @@ -47,7 +47,7 @@ size_t kai_get_lhs_packed_size_lhs_igemm_pack_x8p2vlx4_x8_sme(size_t m, size_t k /// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs array, excluding zero pointers. /// @param[in] zero Pointer to a zero element. Used to check for padding pointers in @ref lhs_ptrs. /// @param[out] lhs_packed Packed LHS matrix. -void kai_run_lhs_igemm_pack_x8p2vlx4_x8_sme( +void kai_run_lhs_imatmul_pack_x8p2vlx4_x8_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, const void* zero, void* lhs_packed); -- GitLab From 60209185a9fd8756450cfeab2db16afbbfab9d83 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Tue, 1 Apr 2025 13:31:49 +0000 Subject: [PATCH 05/14] Rename LHS imatmul pack kernel The kernel require the indirection pointers to be laid out in a packed manner, which need to be indicated by the input type name Signed-off-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ...> kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c} | 23 ++++++++++--------- ...> kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h} | 8 +++---- 4 files changed, 18 insertions(+), 17 deletions(-) rename kai/ukernels/matmul/pack/{kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c => kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c} (96%) rename kai/ukernels/matmul/pack/{kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h => kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h} (86%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 620e1671..2d39c704 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,7 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME - kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index c9485567..8d00d3cb 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -114,7 +114,7 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ - "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme", + "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c similarity index 96% rename from kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c rename to kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c index ca794e2b..3e682437 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c @@ -11,7 +11,7 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. -#include "kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h" +#include "kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" #include #include @@ -22,33 +22,34 @@ #define KR 4 #define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR) -static size_t kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme(void) { +static size_t kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void) { return MR * kai_get_sme_vector_length_u8() / KR; } -size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8_sme(void) { - return kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme(); +size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void) { + return kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(); } -size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8_sme( +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { - KAI_ASSUME(m_idx % kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8_sme() == 0); + KAI_ASSUME(m_idx % kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme() == 0); return m_idx * k_chunk_count * kai_roundup(k_chunk_length, KR) * sizeof(int8_t); } -size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length) { - const size_t m_end = kai_roundup(m, kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme()); - return kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8_sme(m_end, k_chunk_count, k_chunk_length); +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length) { + const size_t m_end = kai_roundup(m, kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme()); + return kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme(m_end, k_chunk_count, k_chunk_length); } -void kai_run_lhs_imatmul_pack_x8p2vlx4_x8_sme( +void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, const void* zero, void* lhs_packed) { KAI_ASSUME(lhs_ptrs != NULL); KAI_ASSUME(lhs_packed != NULL); - const size_t m_step = kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8_sme(); + const size_t m_step = kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(); const size_t row_offset = 0; const size_t width = k_chunk_length; diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h similarity index 86% rename from kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h rename to kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h index 762f21d6..721c07a8 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h @@ -17,7 +17,7 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return Step size for row index -size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8_sme(void); +size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void); /// Gets the offset in bytes to the data element in the packed LHS buffer. /// @@ -26,7 +26,7 @@ size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8_sme(void); /// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8_sme( +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); /// Gets the size in bytes of the packed LHS buffer. @@ -36,7 +36,7 @@ size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8_sme( /// @param[in] k_chunk_length Length of a LHS column split. /// /// @return The size in bytes of the packed LHS buffer. -size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); /// Pack the LHS matrix for use with indirect matrix multiplication /// @@ -47,7 +47,7 @@ size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8_sme(size_t m, size_t /// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs array, excluding zero pointers. /// @param[in] zero Pointer to a zero element. Used to check for padding pointers in @ref lhs_ptrs. /// @param[out] lhs_packed Packed LHS matrix. -void kai_run_lhs_imatmul_pack_x8p2vlx4_x8_sme( +void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, const void* zero, void* lhs_packed); -- GitLab From a8f641391964ef0491d8716fbf337abbb0f61db0 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 2 Apr 2025 09:12:36 +0000 Subject: [PATCH 06/14] Add unit testing for IGEMM QAI8 There are some TODOs left in code. They should be addressed, but the change is useful as is. There is also an issue where the RHS packing doesn't seem to be working. This is worked around by always packing entire RHS, and not only the part that is needed for the output portion. This need to be investigated before releasing to the wild Signed-off-by: Emil Ohlsson Reviewed-by: Emil Ohlsson Approved-by: Jakub Sujak --- test/reference/matmul.cpp | 84 +++- test/reference/matmul.hpp | 13 +- test/reference/reorder.cpp | 4 +- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 405 ++++++++++++++++-- 4 files changed, 473 insertions(+), 33 deletions(-) diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index a5fd1e66..f7888e29 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -12,7 +12,6 @@ #include #include "kai/kai_common.h" -#include "test/common/bfloat16.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" #include "test/common/float16.hpp" @@ -185,6 +184,76 @@ std::vector matmul( return tmp_dst; } +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> +std::vector indirect_matmul_nt_t_quantized( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // + const void* const* lhs_ptrs, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) { + const auto lhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, lhs_quant_width); + const auto rhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, rhs_quant_width); + + std::vector dst(m * n * sizeof(DstData)); + + for (size_t y = 0; y < m; ++y) { + for (size_t x = 0; x < n; ++x) { + DstData acc = 0; + + for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; ++i_k_chunk) { + const void* lhs_data = *(lhs_ptrs + (y * k_chunk_count + i_k_chunk)); + + for (size_t i_k_chunk_len = 0; i_k_chunk_len < k_chunk_length; ++i_k_chunk_len) { + const size_t i = i_k_chunk * k_chunk_length + i_k_chunk_len; + + const auto lhs_data_index = i_k_chunk_len; + const auto lhs_quant_index = (y / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; + const auto lhs_value = read_array(lhs_data, lhs_data_index); + const auto lhs_scale = lhs_scales != nullptr ? read_array(lhs_scales, lhs_quant_index) + : static_cast(1); + const auto lhs_zero_point = lhs_zero_points != nullptr + ? read_array(lhs_zero_points, lhs_quant_index) + : static_cast(0); + + const auto rhs_data_index = x * (k_chunk_count * k_chunk_length) + i; + const auto rhs_quant_index = (x / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; + const auto rhs_value = read_array(rhs_data, rhs_data_index); + const auto rhs_scale = rhs_scales != nullptr ? read_array(rhs_scales, rhs_quant_index) + : static_cast(1); + const auto rhs_zero_point = rhs_zero_points != nullptr + ? read_array(rhs_zero_points, rhs_quant_index) + : static_cast(0); + + acc += (static_cast(lhs_value) - static_cast(lhs_zero_point)) * + static_cast(lhs_scale) * + (static_cast(rhs_value) - static_cast(rhs_zero_point)) * + static_cast(rhs_scale); + } + } + + if (bias_data != nullptr) { + const auto bias_value = read_array(bias_data, x); + const auto bias_scale = bias_scales != nullptr + ? read_array(bias_scales, x / bias_quant_width) + : static_cast(1); + const auto bias_zero_point = bias_zero_points != nullptr + ? read_array(bias_zero_points, x / bias_quant_width) + : static_cast(0); + + acc += (static_cast(bias_value) - static_cast(bias_zero_point)) * + static_cast(bias_scale); + } + + write_array(dst.data(), y * n + x, acc); + } + } + + return dst; +} + template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> @@ -207,7 +276,7 @@ std::vector matmul_nt_t_quantized( for (size_t i = 0; i < k; ++i) { const auto lhs_data_index = row * k + i; - const auto lhs_quant_index = row / lhs_quant_height * lhs_num_quant_per_row + i / lhs_quant_width; + const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; const auto lhs_value = read_array(lhs_data, lhs_data_index); const auto lhs_scale = lhs_scales != nullptr ? read_array(lhs_scales, lhs_quant_index) : static_cast(1); @@ -216,7 +285,7 @@ std::vector matmul_nt_t_quantized( : static_cast(0); const auto rhs_data_index = col * k + i; - const auto rhs_quant_index = col / rhs_quant_height * rhs_num_quant_per_row + i / rhs_quant_width; + const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; const auto rhs_value = read_array(rhs_data, rhs_data_index); const auto rhs_scale = rhs_scales != nullptr ? read_array(rhs_scales, rhs_quant_index) : static_cast(1); @@ -259,6 +328,15 @@ matmul_nt_t_quantized +indirect_matmul_nt_t_quantized( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // + const void* const* lhs_ptrs, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 9a8ce9f8..913de1f8 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -156,4 +156,15 @@ std::vector matmul_nt_t_quantized( size_t rhs_quant_width, // const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> +std::vector indirect_matmul_nt_t_quantized( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // + const void* const* lhs_ptrs, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + } // namespace kai::test diff --git a/test/reference/reorder.cpp b/test/reference/reorder.cpp index 564f96f6..61ba67d1 100644 --- a/test/reference/reorder.cpp +++ b/test/reference/reorder.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -46,5 +46,7 @@ std::vector reorder_block( template std::vector reorder_block( const void* src, size_t height, size_t width, size_t block_height, size_t block_width); +template std::vector reorder_block( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width); } // namespace kai::test diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 2ffe5b70..00e83e13 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -9,14 +9,22 @@ #include #include #include +#include +#include +#include #include #include +#include #include +#include +#include "kai/kai_common.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include "test/common/cpu_info.hpp" #include "test/common/matrix_portion.hpp" @@ -37,6 +45,12 @@ namespace kai::test { using Buffer = std::vector; +using IndirectionBuffer = std::vector; + +struct KChunk { + size_t count; + size_t length; +}; struct LhsPackKernel { std::function get_m_step; @@ -49,6 +63,16 @@ struct LhsPackKernel { pack; }; +struct LhsPackIndirectKernel { + std::function get_m_step; + std::function get_packed_lhs_offset; + std::function get_packed_lhs_size; + std::function + pack; +}; + struct RhsPackKernel { std::function get_n_step; std::function get_rhs_offset; @@ -63,6 +87,19 @@ struct RhsPackKernel { pack; }; +struct RhsPackIndirectKernel { + std::function get_n_step; + std::function get_rhs_offset; + std::function get_bias_offset; + std::function get_scale_offset; + std::function get_packed_rhs_offset; + std::function get_packed_rhs_size; + std::function + pack; +}; + struct MatMulKernel { std::function get_m_step; std::function get_n_step; @@ -102,6 +139,18 @@ struct MatMulVariant { MatMulKernel matmul; ///< Matmul kernel interface }; +struct IndirectMatMulVariant { + std::string_view name; ///< Test identification + MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) + MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) + + std::function is_supported; ///< HW support check + + LhsPackIndirectKernel lhs_pack; ///< LHS packing kernel interface + RhsPackIndirectKernel rhs_pack; ///< RHS packing kernel interface + MatMulKernel matmul; ///< Matmul kernel interface +}; + const std::array gemm_variants = { MatMulVariant{ .name = "matmul_qai8_qai8p_qsi8cxp", @@ -146,6 +195,59 @@ const std::array gemm_variants = { }, }; +const std::array indirect_gemm_variants = { + IndirectMatMulVariant{ + .name = "indirect_matmul_qai8_qai8p_qsi8cxp", + .acc_pack{ + .m = 2 * get_sme_vector_length(), + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + .acc_step{ + .m = 2 * get_sme_vector_length(), + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + + .is_supported = cpu_has_sme2, + + .lhs_pack = + LhsPackIndirectKernel{ + .get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + .get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + .get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + .pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + }, + .rhs_pack = + RhsPackIndirectKernel{ + .get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_packed_rhs_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + }, + .matmul = + MatMulKernel{ + .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + }, + }, +}; + const std::array gemv_variants = { MatMulVariant{ .name = "matmul_qai8_qai8_qsi8cxp", @@ -215,6 +317,8 @@ struct TestReference { Buffer lhs_qai8; Buffer lhs_qai8_scales; Buffer lhs_qai8_zero_points; + IndirectionBuffer lhs_qai8_indirect; + Buffer lhs_qai8_indirect_packed; Buffer rhs_qsi8; Buffer rhs_scales; @@ -242,12 +346,26 @@ static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel matmul_clamp_qai8_qai8_ .run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, }; +// M, N, K, k_chunk_length, pack.m, pack.n, pack.k +using TestDataId = std::tuple; +// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) +static std::map g_data; +// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) + /// Generate test reference data -static TestReference get_test_reference(const MatMulShape& shape, const MatMulVariant& variant) { +static const TestReference& get_test_reference( + const MatMulShape& shape, const MatMulShape& pack_shape, size_t k_chunk_len) { // ============================================================ // Generates input and reference output data // ============================================================ + // Attempt to find test data in cache + const TestDataId data_id{shape.m, shape.n, shape.k, k_chunk_len, pack_shape.m, pack_shape.n, pack_shape.k}; + const auto data_it = g_data.find(data_id); + if (data_it != g_data.end()) { + return data_it->second; + } + // Generates the input data in floating-point. const auto lhs_f32 = fill_random(shape.m * shape.k, seed); const auto rhs_f32 = fill_random(shape.k * shape.n, seed); @@ -265,6 +383,23 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa const auto lhs_scale = read_array(lhs_qai8_scales.data(), 0); const auto lhs_zero_point = read_array(lhs_qai8_zero_points.data(), 0); + IndirectionBuffer lhs_qai8_indirect; + + const size_t k_chunk_count = shape.k / k_chunk_len; + assert(k_chunk_count * k_chunk_len == shape.k); + + // Setup an indirection buffer, where each "row" contains `k_chunk_count` + // pointers to chunks of length `k_chunk_len` in the input_buffer + for (size_t m_i = 0; m_i < shape.m; ++m_i) { + for (size_t k_chunk_idx = 0; k_chunk_idx < k_chunk_count; ++k_chunk_idx) { + lhs_qai8_indirect.push_back(&lhs_qai8.at(m_i * shape.k + k_chunk_idx * k_chunk_len)); + } + } + + // Reorder indirection pointers to layout the packing kernel expectes + Buffer lhs_qai8_indirect_packed = reorder_block( + reinterpret_cast(lhs_qai8_indirect.data()), shape.m, k_chunk_count, pack_shape.m, 1); + // Transpose, then quantize symmetrically, then transpose back. This will give one // quantization value for each column const auto rhs_f32_t = transpose(rhs_f32.data(), shape.k, shape.n); @@ -281,9 +416,10 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa // Runs the reference implementation of matmul to produce floating-point result. const auto ref_dst_f32 = - matmul_nt_t_quantized( - shape.m, shape.n, shape.k, // matmul shape - lhs_qai8.data(), &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point + indirect_matmul_nt_t_quantized( + shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape + reinterpret_cast(lhs_qai8_indirect.data()), &lhs_scale, + &lhs_zero_point, // LHS, scaling factor and zero point shape.m, shape.k, // LHS quantization window shape rhs_qsi8_t.data(), rhs_scales.data(), nullptr, // RHS scaling factors 1, shape.k, // RHS quantization window shape @@ -330,12 +466,12 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa // The reference packing functions cannot be executed earlier // because we need the reference floating-point output first to have // the quantization information. - auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k); + auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, pack_shape.m, pack_shape.k); auto packed_rhs = matmul_pack_rhs_nxk_static_quantized( rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, - variant.acc_pack.n, variant.acc_pack.k); + pack_shape.n, pack_shape.k); - return { + const TestReference& reference = g_data[data_id] = { .clamp = {.min = dst_qai8_clamp_min, .max = dst_qai8_clamp_max}, .qa_lhs = {.scale = lhs_scale, .zero_point = lhs_zero_point}, @@ -344,6 +480,8 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa .lhs_qai8 = std::move(lhs_qai8), .lhs_qai8_scales = std::move(lhs_qai8_scales), .lhs_qai8_zero_points = std::move(lhs_qai8_zero_points), + .lhs_qai8_indirect = std::move(lhs_qai8_indirect), + .lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed), .rhs_qsi8 = std::move(rhs_qsi8), .rhs_scales = std::move(rhs_scales), @@ -355,6 +493,7 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa .packed_lhs = std::move(packed_lhs), .packed_rhs = std::move(packed_rhs), }; + return reference; } /// Test LHS packing @@ -432,6 +571,44 @@ static void test_rhs_pack( ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing"; } +static void compare_matmul_result( + const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) { + size_t mismatches = 0; + bool printed_row = false; + bool printed_mismatch = false; + for (size_t m_i = 0; m_i < shape.m; ++m_i) { + for (size_t n_i = 0; n_i < shape.n; ++n_i) { + const auto i = m_i * shape.n + n_i; + const auto in_area = m_i >= output_area.start_row() && m_i < output_area.end_row() && + n_i >= output_area.start_col() && n_i < output_area.end_col(); + + const auto imp_value = read_array(actual.data(), i); + const auto ref_value = in_area ? read_array(reference.data(), i) : 0; + const auto error = std::abs(imp_value - ref_value); + const auto threshold = in_area ? 1 : 0; + const bool mismatch = error > threshold; + if (mismatch) { + if (not printed_mismatch) { + std::cout << "Mismatch(es) found:\n"; + printed_mismatch = true; + } + if (not printed_row) { + std::cout << " row=" << m_i; + std::cout << " "; + printed_row = true; + } + std::cout << n_i << ", "; + } + mismatches += static_cast(mismatch); + } + if (printed_row) { + std::cout << "\n"; + } + printed_row = false; + } + ASSERT_EQ(mismatches, 0) << "There are mismatches between reference result actual result"; +} + /// Test MatMul of GEMM/GEMV like kernel static void test_matmul( const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { @@ -461,25 +638,12 @@ static void test_matmul( reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t), sizeof(int8_t), &imp_main_params); - size_t mismatches = 0; - for (size_t y = 0; y < shape.m; ++y) { - for (size_t x = 0; x < shape.n; ++x) { - const auto i = y * shape.n + x; - const auto in_area = y >= output_area.start_row() && y < output_area.end_row() && - x >= output_area.start_col() && x < output_area.end_col(); - - const auto imp_value = read_array(imp_dst.data(), i); - const auto ref_value = in_area ? read_array(reference.dst_qsi8_clamped.data(), i) : 0; - const auto error = std::abs(imp_value - ref_value); - const auto threshold = in_area ? 1 : 0; - - mismatches += static_cast(error > threshold); - } - } - ASSERT_EQ(mismatches, 0) << "There are mismatched between reference result actual result"; + compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped); } -using ThisTest = testing::TestWithParam>; +using MatMulQuantizedTest = testing::TestWithParam>; +using IndirectMatMulQuantizedTest = + testing::TestWithParam>; static std::string test_description( const MatMulVariant& variant, // @@ -495,14 +659,29 @@ static std::string test_description( return sstream.str(); }; -TEST_P(ThisTest, EndToEnd) { +static std::string test_description( + const IndirectMatMulVariant& variant, // + const MatMulShape& shape, // + const MatrixPortion& portion, size_t k_chunk_len) { + std::stringstream sstream; + sstream << "Method_" << variant.name << "__M_" // + << shape.m << "__N_" << shape.n << "__k_chunk_count_" << shape.k // + << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // + << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // + << "__PortionHeight_" << static_cast(portion.height() * 1000) // + << "__PortionWidth_" << static_cast(portion.width() * 1000) // + << "__k_chunk_len_" << k_chunk_len; + return sstream.str(); +}; + +TEST_P(MatMulQuantizedTest, EndToEnd) { const auto& [variant, shape, output_portion] = GetParam(); if (!variant.is_supported()) { GTEST_SKIP() << "CPU features are not supported by current CPU"; } - TestReference reference = get_test_reference(shape, variant); + TestReference reference = get_test_reference(shape, variant.acc_pack, 1); // Check scheduling parameters const auto imp_mr = variant.matmul.get_mr(); @@ -532,8 +711,130 @@ TEST_P(ThisTest, EndToEnd) { test_matmul(shape, variant, matmul_portion, reference); } +namespace imatmul { + +/// Perform LHS IMATMUL packing +static Buffer lhs_pack( + const LhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t m, + const KChunk& k_chunk) { + const void* const* indirection_pointer = + reinterpret_cast(reference.lhs_qai8_indirect_packed.data()); + + // Allocate buffer + const size_t dst_size = variant.get_packed_lhs_size(m, k_chunk.count, k_chunk.length); + Buffer packed(dst_size); + + // Calculate offsets + const size_t input_offset = portion.start_row() * k_chunk.count; + const size_t dst_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); + + // TODO: `lhs_offset` is currently not being excercized! + // TODO: Ensure that `zero` pointers are tested + variant.pack( + portion.height(), k_chunk.count, k_chunk.length, // Dimensions + indirection_pointer + input_offset, // Indirection input + 0, // chunk offset + nullptr, // padding pointer + packed.data() + dst_offset); + + return packed; +} + +/// Perform RHS IMATMUL packing +static Buffer rhs_pack( + const RhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t n, + const KChunk& k_chunk) { + // Allocate output buffer + const size_t dst_size = variant.get_packed_rhs_size(n, k_chunk.count, k_chunk.length); + Buffer packed_all(dst_size); + Buffer packed(dst_size); + + // Caluclate effective quantization parameters + const kai_rhs_pack_qsi8cx_params quantization{ + reference.qa_lhs.zero_point, + reference.qa_lhs.scale / reference.qa_dst.scale, + }; + + // Calculate offsets + const size_t rhs_offset = variant.get_rhs_offset(portion.start_col()); + const size_t bias_offset = variant.get_bias_offset(portion.start_col()); + const size_t scale_offset = variant.get_scale_offset(portion.start_col()); + const size_t dst_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + + // Pack + variant.pack( + portion.width(), k_chunk.count, k_chunk.length, + n * sizeof(uint8_t), // Dimensions, row stride + reference.rhs_qsi8.data() + rhs_offset, // RHS matrix + reference.bias_qsi32.data() + bias_offset, // Bias + reference.rhs_scales.data() + scale_offset, // Scales + packed.data() + dst_offset, // Output + &quantization); + + return packed; +} + +/// Calculate the matmul result from IMATMUL kernels +static Buffer matmul( + const MatMulKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs, + const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) { + // TODO: This variable is no longer needed when we generate imatmul kernel. + // For now, this is equivalent of passing `k_chunk.count` and `k_chunk.length` + const size_t indirect_k = k_chunk.count * kai_roundup(k_chunk.length, variant.get_kr()); + + // Calculate portion offsets. + size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n); + size_t lhs_offset = variant.get_packed_lhs_offset(portion.start_row(), indirect_k); + size_t rhs_offset = variant.get_packed_rhs_offset(portion.start_col(), indirect_k); + + // Allocate output buffer + const size_t dst_size = variant.get_dst_size(shape.m, shape.n); + Buffer dst(dst_size); + + // Calculate geffective uantization parameters + kai_matmul_requantize32_params requantization{ + .min_value = reference.clamp.min, + .max_value = reference.clamp.max, + .output_zero_point = reference.qa_dst.zero_point, + }; + + // Call matmul kernel + variant.matmul( + portion.height(), portion.width(), indirect_k, // Dimensions + packed_lhs.data() + lhs_offset, // LHS + packed_rhs.data() + rhs_offset, // RHS + dst.data() + dst_offset, // DST + shape.n * sizeof(uint8_t), sizeof(uint8_t), &requantization); + + // TODO: Ensure `clamp` is tested + + return dst; +} +} // namespace imatmul + +TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { + /* This is a bit special, as shape.k must be k_chunk_len * k_chunk_count + * so instead of inventing a new special kind of shape, simply multiply + * with `k_chunk_len` here */ + const auto& [variant, shape_k_chunk, output_portion, k_chunk_len] = GetParam(); + const KChunk k_chunk{shape_k_chunk.k, k_chunk_len}; + MatMulShape shape{shape_k_chunk.m, shape_k_chunk.n, k_chunk.count * k_chunk.length}; + + if (!variant.is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + const TestReference& reference = get_test_reference(shape, variant.acc_pack, k_chunk.length); + const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); + + Buffer packed_lhs = imatmul::lhs_pack(variant.lhs_pack, portion, reference, shape.m, k_chunk); + Buffer packed_rhs = imatmul::rhs_pack(variant.rhs_pack, portion, reference, shape.n, k_chunk); + Buffer impl_result = imatmul::matmul(variant.matmul, portion, reference, packed_lhs, packed_rhs, shape, k_chunk); + compare_matmul_result(shape, portion, impl_result, reference.dst_qsi8_clamped); +} + INSTANTIATE_TEST_SUITE_P( - matmul_clamp_qai8_qai8p_qsi8cxp, ThisTest, + matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest, testing::Combine( testing::ValuesIn(gemm_variants), testing::ValuesIn({ @@ -565,7 +866,7 @@ INSTANTIATE_TEST_SUITE_P( }); INSTANTIATE_TEST_SUITE_P( - matmul_clamp_qai8_qai8_qsi8cxp, ThisTest, + matmul_clamp_qai8_qai8_qsi8cxp, MatMulQuantizedTest, testing::Combine( testing::ValuesIn(gemv_variants), testing::ValuesIn({ @@ -598,4 +899,52 @@ INSTANTIATE_TEST_SUITE_P( std::get(info.param), // std::get(info.param)); }); + +INSTANTIATE_TEST_SUITE_P( + indirect_matmul_clamp_qai8_qai8p_qsi8cxp, IndirectMatMulQuantizedTest, + testing::Combine( + testing::ValuesIn(indirect_gemm_variants), + testing::ValuesIn({ + // clang-format off + // M, N, k_chunk_count + MatMulShape{ 1, 1, 1}, + MatMulShape{ 1, 19, 24}, + MatMulShape{ 1, 49, 21}, + MatMulShape{ 2, 195, 50}, + MatMulShape{ 3, 6, 6}, + MatMulShape{ 3, 28, 25}, + MatMulShape{ 3, 184,177}, + MatMulShape{ 4, 16, 27}, + MatMulShape{ 5, 136, 23}, + MatMulShape{ 6, 18, 31}, + MatMulShape{ 6, 28, 1}, + MatMulShape{ 6, 29, 24}, + MatMulShape{ 32, 16, 27}, + MatMulShape{ 32, 32, 3}, + MatMulShape{ 33, 29, 24}, + MatMulShape{ 64, 64, 3}, + MatMulShape{ 96, 96, 3}, + MatMulShape{128, 128, 3}, + // clang-format on + }), + testing::ValuesIn({ + // clang-format off + // (Start row , start col , height , width) + MatrixPortion( 0 , 0 , 1 , 1) , // Full matrix. + MatrixPortion( 0 , 0 , 1 , 0.5) , // Left half + MatrixPortion( 0 , 0 , 0.5 , 1) , // Upper half + MatrixPortion( 0 , 0.5 , 1 , 0.5) , // Right half + MatrixPortion( 0.5 , 0 , 0.5 , 1) , // Bottom half + MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3) , // Center ninth + // clang-format on + }), + // k_chunk_len + testing::ValuesIn(std::initializer_list{2, 3, 4, 8, 11, 32})), + [](const auto& info) -> std::string { + return test_description( + std::get(info.param), // + std::get(info.param), // + std::get(info.param), // + std::get(info.param)); + }); } // namespace kai::test -- GitLab From 5762bfde34d800a94727afbbf7bcc236e6a5d16d Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Wed, 2 Apr 2025 11:53:10 +0000 Subject: [PATCH 07/14] Fix RHS igemm packing MAX_N_STEP macro Signed-off-by: Mohammed Suhail Munshi Reviewed-by: Emil Ohlsson Approved-by: Emil Ohlsson --- ...atmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c index b8bc2cbb..3db501b6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -18,17 +18,17 @@ #include "kai/kai_common.h" -static const size_t kai_nr = 2; -static const size_t kai_kr = 4; static const size_t kai_num_bytes_input = sizeof(uint8_t); static const size_t kai_num_bytes_output = sizeof(uint8_t); static const size_t kai_num_bytes_bias = sizeof(int32_t); static const size_t kai_num_bytes_scale = sizeof(float32_t); -#define MAX_N_STEP (KAI_SME_VEC_LENGTH_MAX_BYTES / (kai_kr * kai_nr)) +#define NR 2 +#define KR 4 +#define MAX_N_STEP (NR * KAI_SME_VEC_LENGTH_MAX_BYTES / KR) size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { - return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; + return NR * kai_get_sme_vector_length_u8() / KR; } size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { @@ -48,7 +48,7 @@ size_t kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sm static size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t k_chunk_count, size_t k_chunk_length) { return kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() * - (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * kai_num_bytes_output + + (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output + kai_num_bytes_scale); } @@ -84,7 +84,7 @@ void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); uint8_t pad_row[MAX_N_STEP]; - if (height % kai_kr) { + if (height % KR) { memset(pad_row, 0, MAX_N_STEP * sizeof(uint8_t)); } -- GitLab From 0a91779ba9d092ecb8393438c39ff2e5abd0b8e1 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 2 Apr 2025 13:16:10 +0000 Subject: [PATCH 08/14] Add IMATMUL version of QAI8 This change adds a IMATMUL version of the QAI8 kernel, as well as changes unit tests to call into this new kernel and adds an interface for this kernel as well Signed-off-by: Emil Ohlsson Reviewed-by: Emil Ohlsson Approved-by: Jakub Sujak --- CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c | 412 ++++++++++++++++++ ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h | 121 +++++ ...atmul_clamp_qai8_qai8p_qsi8cxp_interface.h | 52 +++ .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 80 ++-- 6 files changed, 642 insertions(+), 25 deletions(-) create mode 100644 kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c create mode 100644 kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h create mode 100644 kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d39c704..98a03b02 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,6 +195,7 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2 + kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 8d00d3cb..aa257daa 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -131,6 +131,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS = [ + "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c new file mode 100644 index 00000000..d53b1f0c --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -0,0 +1,412 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 4; + +size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return m_idx * indirect_k * sizeof(int8_t); +} + +static size_t kai_get_rhs_packed_stride_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t k_chunk_count, size_t k_chunk_length) { + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() * + (sizeof(int32_t) + indirect_k * sizeof(int8_t) + sizeof(float)); +} + +size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + const size_t block_idx = n_idx / kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(); + return block_idx * + kai_get_rhs_packed_stride_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + k_chunk_count, k_chunk_length); +} + +size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_row_stride + n_idx * sizeof(int8_t); +} + +size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(int8_t); +} + +void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params) { + typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + int32_t min; + int32_t max; + int32_t result_zero_point; + const int n_0; + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + + args.C = dst; + args.ldcb = dst_row_stride; + args.M = m; + args.N = n; + args.K = indirect_k; + args.min = params->min_value; + args.max = params->max_value; + args.result_zero_point = params->output_zero_point; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w14, [%x[args], %[offsetof_M]]\n" + "mov x13, #0x0\n" + "mov x11, #0x0\n" + "ptrue p1.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr w10, [%x[args], %[offsetof_N]]\n" + "ldr x9, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x28, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "mov x27, x9\n" + ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias + "addvl x28, x28, #2\n" + ".inst 0xc09025c0 // addha za0.s, p1/M, p1/M, z14.s\n" + ".inst 0xc09025e1 // addha za1.s, p1/M, p1/M, z15.s\n" + ".inst 0xc09025c2 // addha za2.s, p1/M, p1/M, z14.s\n" + ".inst 0xc09025e3 // addha za3.s, p1/M, p1/M, z15.s\n" + "ldr x20, [%x[args], %[offsetof_K]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" + ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" + ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" + ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" + ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" + ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" + ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" + ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" + ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" + ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" + ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" + ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" + ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" + ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" + ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" + ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" + ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" + ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" + ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" + ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" + ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" + ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" + ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" + ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" + ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" + ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" + ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" + ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" + ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" + ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" + "6:" // K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa0400770 // ld1b { z16.b-z17.b }, pn9.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400788 // ld1b { z8.b-z9.b }, pn9.b/Z, [x28]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0882600 // smopa za0.s, p1/M, p1/M, z16.b, z8.b\n" + ".inst 0xa0892601 // smopa za1.s, p1/M, p1/M, z16.b, z9.b\n" + ".inst 0xa0882622 // smopa za2.s, p1/M, p1/M, z17.b, z8.b\n" + ".inst 0xa0892623 // smopa za3.s, p1/M, p1/M, z17.b, z9.b\n" + "bgt 7b\n" + "8:" // K oddments: End + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x14, x13\n" + "cntw x24\n" + "ld1rw { z27.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x23, [%x[args], %[offsetof_ldcb]]\n" + "whilelt p0.h, x11, x10\n" + "cmp x25, x24\n" + "ld1rw { z1.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x24, LT\n" + "ld1rw { z0.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_result_zero_point]]\n" + "mov x12, #0x0\n" + "add x26, x26, x11\n" // C += n + "lsr x21, x22, #0x2\n" + "ld1w { z22.s }, p1/Z, [x28]\n" + "madd x26, x13, x23, x26\n" // C += m * ldc + "ld1w { z26.s }, p1/Z, [x28, #1, MUL VL]\n" + "and x20, x22, #0x3\n" + "addvl x28, x28, #2\n" + "cbz x21, 11f\n" + "10:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmul z16.s, z16.s, z22.s\n" + "fmul z17.s, z17.s, z22.s\n" + "add x12, x12, #0x4\n" + "fmul z18.s, z18.s, z22.s\n" + "fmul z19.s, z19.s, z22.s\n" + "cmp x12, x21, LSL #2\n" + "fmul z28.s, z28.s, z26.s\n" + "fmul z29.s, z29.s, z26.s\n" + "fmul z30.s, z30.s, z26.s\n" + "fmul z31.s, z31.s, z26.s\n" + ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" + ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" + ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf7c // sclamp { z28.s-z31.s }, z27.s, z1.s\n" + "uzp1 z5.h, z16.h, z28.h\n" + "uzp1 z20.h, z17.h, z29.h\n" + "uzp1 z17.h, z18.h, z30.h\n" + "uzp1 z16.h, z19.h, z31.h\n" + "st1b { z5.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z20.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z17.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "blt 10b\n" + "11:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 12f\n" + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmul z4.s, z4.s, z22.s\n" + "fmul z5.s, z5.s, z22.s\n" + "subs x20, x20, #0x1\n" + "fmul z6.s, z6.s, z22.s\n" + "fmul z7.s, z7.s, z22.s\n" + "fmul z12.s, z12.s, z26.s\n" + "fmul z13.s, z13.s, z26.s\n" + "fmul z14.s, z14.s, z26.s\n" + "fmul z15.s, z15.s, z26.s\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" + ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" + "uzp1 z16.h, z4.h, z12.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 12f\n" + "subs x20, x20, #0x1\n" + "uzp1 z16.h, z5.h, z13.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 12f\n" + "uzp1 z16.h, z6.h, z14.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "12:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 16f\n" + "cmp x25, x24\n" + "mov x12, #0x0\n" + "csel x20, x25, x24, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 14f\n" + "13:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmul z8.s, z8.s, z22.s\n" + "fmul z9.s, z9.s, z22.s\n" + "add x12, x12, #0x4\n" + "fmul z10.s, z10.s, z22.s\n" + "fmul z11.s, z11.s, z22.s\n" + "cmp x12, x21, LSL #2\n" + "fmul z16.s, z16.s, z26.s\n" + "fmul z17.s, z17.s, z26.s\n" + "fmul z18.s, z18.s, z26.s\n" + "fmul z19.s, z19.s, z26.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" + ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" + ".inst 0xc1a1cf68 // sclamp { z8.s-z11.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" + "uzp1 z21.h, z8.h, z16.h\n" + "uzp1 z20.h, z9.h, z17.h\n" + "uzp1 z17.h, z10.h, z18.h\n" + "uzp1 z16.h, z11.h, z19.h\n" + "st1b { z21.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z20.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z17.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "blt 13b\n" + "14:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 15f\n" + ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" + ".inst 0xc0860464 // mova { z4.s-z7.s }, za3h.s[x12]\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + "fmul z12.s, z12.s, z22.s\n" + "fmul z13.s, z13.s, z22.s\n" + "subs x20, x20, #0x1\n" + "fmul z14.s, z14.s, z22.s\n" + "fmul z15.s, z15.s, z22.s\n" + "fmul z4.s, z4.s, z26.s\n" + "fmul z5.s, z5.s, z26.s\n" + "fmul z6.s, z6.s, z26.s\n" + "fmul z7.s, z7.s, z26.s\n" + ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" + "uzp1 z16.h, z12.h, z4.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + "subs x20, x20, #0x1\n" + "uzp1 z16.h, z13.h, z5.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + "uzp1 z16.h, z14.h, z6.h\n" + "st1b { z16.h }, p0, [x26]\n" + "15:" // Store to output array: Accumulator row 1 oddments: End + "16:" // Store to output array: End + "incw x11, ALL, MUL #2\n" + "cmp x11, x10\n" + "blt 2b\n" + "incw x13, ALL, MUL #2\n" + "mov x11, #0x0\n" + "cmp x13, x14\n" + "mov x9, x27\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), + [offsetof_KernelArgs_result_zero_point] "I"(offsetof(KernelArgs, result_zero_point)), + [offsetof_M] "I"(offsetof(KernelArgs, M)), [offsetof_N] "I"(offsetof(KernelArgs, N)), + [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", + "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", + "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", + "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h new file mode 100644 index 00000000..cc501580 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -0,0 +1,121 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// -# kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme to pack the LHS matrix. +/// -# kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets mr value. +/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_row_stride. Distance between start of two rows in the output buffer. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split +/// @param[in] packed_lhs Packed LHS matrix buffer. +/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_row_stride Row stride in bytes of the output matrix. + +/// @param[in] params Requantization and clamp parameters. + +void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h new file mode 100644 index 00000000..0a1e23f9 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -0,0 +1,52 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: imatmul_clamp_qai8_qai8p_qsi8cxp + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_m_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_n_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_mr_func_t)(void); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_nr_func_t)(void); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_kr_func_t)(void); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func_t)( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t)( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t)( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + +/// Micro-kernel interface +struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_m_step_func_t get_m_step; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_n_step_func_t get_n_step; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_mr_func_t get_mr; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_nr_func_t get_nr; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_kr_func_t get_kr; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t get_dst_offset; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t get_dst_size; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 00e83e13..7168e042 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -19,6 +19,8 @@ #include #include "kai/kai_common.h" +#include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" @@ -117,6 +119,22 @@ struct MatMulKernel { matmul; }; +struct MatMulIndirectKernel { + std::function get_m_step; + std::function get_n_step; + std::function get_mr; + std::function get_nr; + std::function get_kr; + std::function get_packed_lhs_offset; + std::function get_packed_rhs_offset; + std::function get_dst_offset; + std::function get_dst_size; + std::function + matmul; +}; + const static RhsPackKernel rhs_pack = { .get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, .get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, @@ -148,7 +166,7 @@ struct IndirectMatMulVariant { LhsPackIndirectKernel lhs_pack; ///< LHS packing kernel interface RhsPackIndirectKernel rhs_pack; ///< RHS packing kernel interface - MatMulKernel matmul; ///< Matmul kernel interface + MatMulIndirectKernel matmul; ///< Matmul kernel interface }; const std::array gemm_variants = { @@ -230,20 +248,19 @@ const std::array indirect_gemm_variants = { .pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, }, .matmul = - MatMulKernel{ - .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + MatMulIndirectKernel{ + .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_mr = kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_nr = kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_kr = kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .matmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, }, }, }; @@ -346,6 +363,23 @@ static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel matmul_clamp_qai8_qai8_ .run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, }; +/// Make sure that interface matches +static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel + imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface [[maybe_unused]] = { + .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_mr = kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_nr = kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_kr = kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .run_matmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, +}; + // M, N, K, k_chunk_length, pack.m, pack.n, pack.k using TestDataId = std::tuple; // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) @@ -776,16 +810,12 @@ static Buffer rhs_pack( /// Calculate the matmul result from IMATMUL kernels static Buffer matmul( - const MatMulKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs, + const MatMulIndirectKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs, const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) { - // TODO: This variable is no longer needed when we generate imatmul kernel. - // For now, this is equivalent of passing `k_chunk.count` and `k_chunk.length` - const size_t indirect_k = k_chunk.count * kai_roundup(k_chunk.length, variant.get_kr()); - // Calculate portion offsets. size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n); - size_t lhs_offset = variant.get_packed_lhs_offset(portion.start_row(), indirect_k); - size_t rhs_offset = variant.get_packed_rhs_offset(portion.start_col(), indirect_k); + size_t lhs_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); + size_t rhs_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); // Allocate output buffer const size_t dst_size = variant.get_dst_size(shape.m, shape.n); @@ -800,11 +830,11 @@ static Buffer matmul( // Call matmul kernel variant.matmul( - portion.height(), portion.width(), indirect_k, // Dimensions - packed_lhs.data() + lhs_offset, // LHS - packed_rhs.data() + rhs_offset, // RHS - dst.data() + dst_offset, // DST - shape.n * sizeof(uint8_t), sizeof(uint8_t), &requantization); + portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions + packed_lhs.data() + lhs_offset, // LHS + packed_rhs.data() + rhs_offset, // RHS + dst.data() + dst_offset, // DST + shape.n * sizeof(uint8_t), &requantization); // TODO: Ensure `clamp` is tested -- GitLab From 346817acf68d03ca129e928f7195ead92c4ad427 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Fri, 4 Apr 2025 07:39:25 +0000 Subject: [PATCH 09/14] Reuse QAI8 test shapes between GEMM, IGEMM and GEMV As reference data is cached, it's possible to run larger amounts of shapes as long as test data is shared. Signed-off-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 111 ++++++++++-------- 1 file changed, 59 insertions(+), 52 deletions(-) diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 7168e042..d71193a9 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -863,24 +863,52 @@ TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { compare_matmul_result(shape, portion, impl_result, reference.dst_qsi8_clamped); } +static constexpr std::array shapes{ + // clang-format off + MatMulShape{ 1, 1, 1}, + MatMulShape{ 1, 16, 4}, + MatMulShape{ 1, 16, 16}, + MatMulShape{ 1, 17, 4}, + MatMulShape{ 1, 19, 24}, + MatMulShape{ 1, 32, 4}, + MatMulShape{ 1, 32, 32}, + MatMulShape{ 1, 33,200}, + MatMulShape{ 1, 49, 21}, + MatMulShape{ 1, 64, 4}, + MatMulShape{ 1, 65, 4}, + MatMulShape{ 1, 300, 10}, + MatMulShape{ 1, 512, 4}, + MatMulShape{ 1, 1523, 10}, + MatMulShape{ 2, 195, 50}, + MatMulShape{ 3, 6, 6}, + MatMulShape{ 3, 28, 25}, + MatMulShape{ 3, 184,177}, + MatMulShape{ 4, 16, 27}, + MatMulShape{ 5, 136, 23}, + MatMulShape{ 6, 18, 31}, + MatMulShape{ 6, 28, 1}, + MatMulShape{ 6, 29, 24}, + MatMulShape{ 16, 16, 4}, + MatMulShape{ 20, 30, 40}, + MatMulShape{ 23, 1, 43}, + MatMulShape{ 32, 14, 1}, + MatMulShape{ 32, 16, 27}, + MatMulShape{ 32, 32, 3}, + MatMulShape{ 32, 32, 4}, + MatMulShape{ 33, 29, 24}, + MatMulShape{ 64, 64, 3}, + MatMulShape{ 64, 64, 4}, + MatMulShape{ 96, 96, 3}, + MatMulShape{123, 85, 45}, + MatMulShape{128, 128, 3}, + MatMulShape{130, 130, 6}, + // clang-format on +}; + INSTANTIATE_TEST_SUITE_P( matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest, testing::Combine( - testing::ValuesIn(gemm_variants), - testing::ValuesIn({ - // clang-format off - MatMulShape{ 1, 1, 1}, - MatMulShape{ 1, 49, 21}, - MatMulShape{ 16, 16, 4}, - MatMulShape{ 20, 30, 40}, - MatMulShape{ 23, 1, 43}, - MatMulShape{ 32, 14, 1}, - MatMulShape{ 32, 32, 4}, - MatMulShape{ 64, 64, 4}, - MatMulShape{123, 85, 45}, - MatMulShape{130, 130, 6}, - // clang-format on - }), + testing::ValuesIn(gemm_variants), testing::ValuesIn(shapes), testing::ValuesIn({ // clang-format off MatrixPortion( 0, 0, 1, 1), // Full matrix. @@ -901,18 +929,20 @@ INSTANTIATE_TEST_SUITE_P( testing::ValuesIn(gemv_variants), testing::ValuesIn({ // clang-format off - MatMulShape{1, 1, 1}, - MatMulShape{1, 16, 4}, - MatMulShape{1, 16, 16}, - MatMulShape{1, 17, 4}, - MatMulShape{1, 32, 4}, - MatMulShape{1, 32, 32}, - MatMulShape{1, 33, 200}, - MatMulShape{1, 64, 4}, - MatMulShape{1, 65, 4}, - MatMulShape{1, 300, 10}, - MatMulShape{1, 512, 4}, - MatMulShape{1, 1523, 10}, + MatMulShape{ 1, 1, 1}, + MatMulShape{ 1, 16, 4}, + MatMulShape{ 1, 16, 16}, + MatMulShape{ 1, 17, 4}, + MatMulShape{ 1, 19, 24}, + MatMulShape{ 1, 32, 4}, + MatMulShape{ 1, 32, 32}, + MatMulShape{ 1, 33,200}, + MatMulShape{ 1, 49, 21}, + MatMulShape{ 1, 64, 4}, + MatMulShape{ 1, 65, 4}, + MatMulShape{ 1, 300, 10}, + MatMulShape{ 1, 512, 4}, + MatMulShape{ 1, 1523, 10}, // clang-format on }), testing::ValuesIn({ @@ -933,30 +963,7 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( indirect_matmul_clamp_qai8_qai8p_qsi8cxp, IndirectMatMulQuantizedTest, testing::Combine( - testing::ValuesIn(indirect_gemm_variants), - testing::ValuesIn({ - // clang-format off - // M, N, k_chunk_count - MatMulShape{ 1, 1, 1}, - MatMulShape{ 1, 19, 24}, - MatMulShape{ 1, 49, 21}, - MatMulShape{ 2, 195, 50}, - MatMulShape{ 3, 6, 6}, - MatMulShape{ 3, 28, 25}, - MatMulShape{ 3, 184,177}, - MatMulShape{ 4, 16, 27}, - MatMulShape{ 5, 136, 23}, - MatMulShape{ 6, 18, 31}, - MatMulShape{ 6, 28, 1}, - MatMulShape{ 6, 29, 24}, - MatMulShape{ 32, 16, 27}, - MatMulShape{ 32, 32, 3}, - MatMulShape{ 33, 29, 24}, - MatMulShape{ 64, 64, 3}, - MatMulShape{ 96, 96, 3}, - MatMulShape{128, 128, 3}, - // clang-format on - }), + testing::ValuesIn(indirect_gemm_variants), testing::ValuesIn(shapes), testing::ValuesIn({ // clang-format off // (Start row , start col , height , width) @@ -969,7 +976,7 @@ INSTANTIATE_TEST_SUITE_P( // clang-format on }), // k_chunk_len - testing::ValuesIn(std::initializer_list{2, 3, 4, 8, 11, 32})), + testing::ValuesIn(std::initializer_list{1, 2, 3, 4, 8, 11, 32})), [](const auto& info) -> std::string { return test_description( std::get(info.param), // -- GitLab From 3d3d1ced4849cfcf0d685f79acfa6adef50f13d2 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Fri, 4 Apr 2025 11:05:13 +0000 Subject: [PATCH 10/14] Add changelog for indirect matmul Signed-off-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1247ef29..e5ce678a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,13 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- New SME micro-kernels: + - Indirect matrix multiplication (MxN) of QAI8 input and output. + - Packing kernels for LHS and RHS +- New SME2 micro-kernels: + - Indirect matrix multiplication (MxN) of QAI8 input and output. + - Matrix multiplication of packed indirect LHS and packed RHS + ## v1.6.0 - Add CMake installation and `find_package()` support. -- GitLab From 457218dc5a6e056ff2198527d33545900c843029 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 7 Apr 2025 10:34:02 +0000 Subject: [PATCH 11/14] Add IGEMM padding pointer testing For LHS shapes which has more than one row, set the first row of data to be padding. As this further increases the number of different test inputs this change also extends the caching mechanism to use an unordered map to store the generated test data, and uses a single object which encompasses all parameters used to generate test data Signed-off-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- test/common/test_suite.hpp | 16 ++++ test/reference/matmul.cpp | 33 ++++--- test/reference/matmul.hpp | 8 +- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 90 ++++++++++++++----- 4 files changed, 110 insertions(+), 37 deletions(-) diff --git a/test/common/test_suite.hpp b/test/common/test_suite.hpp index cc791037..1bc85177 100644 --- a/test/common/test_suite.hpp +++ b/test/common/test_suite.hpp @@ -76,6 +76,22 @@ struct MatMulShape { size_t m{}; ///< LHS height. size_t n{}; ///< RHS width. size_t k{}; ///< LHS width and RHS height. +private: + friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) { + return // + lhs.m == rhs.m && // + lhs.n == rhs.n && // + lhs.k == rhs.k; + } +}; + +struct HashMatMulShape { + size_t operator()(const kai::test::MatMulShape& shape) const { + return // + (std::hash{}(shape.m) << 0) ^ // + (std::hash{}(shape.n) << 1) ^ // + (std::hash{}(shape.k) << 2); + } }; /// Matrix multiplication test information. diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index f7888e29..e735b23f 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -189,7 +189,8 @@ template < typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> std::vector indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // - const void* const* lhs_ptrs, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* lhs_zero_points, size_t lhs_quant_height, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, size_t rhs_quant_width, // @@ -199,27 +200,32 @@ std::vector indirect_matmul_nt_t_quantized( std::vector dst(m * n * sizeof(DstData)); - for (size_t y = 0; y < m; ++y) { - for (size_t x = 0; x < n; ++x) { + for (size_t i_m = 0; i_m < m; ++i_m) { + for (size_t i_n = 0; i_n < n; ++i_n) { DstData acc = 0; for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; ++i_k_chunk) { - const void* lhs_data = *(lhs_ptrs + (y * k_chunk_count + i_k_chunk)); + // Calculate the K chunk pointer. Apply offset if this is not padding + const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk; + const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx]; + if (k_chunk_ptr != lhs_padding) { + k_chunk_ptr = reinterpret_cast(reinterpret_cast(k_chunk_ptr) + lhs_offset); + } for (size_t i_k_chunk_len = 0; i_k_chunk_len < k_chunk_length; ++i_k_chunk_len) { const size_t i = i_k_chunk * k_chunk_length + i_k_chunk_len; const auto lhs_data_index = i_k_chunk_len; - const auto lhs_quant_index = (y / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; - const auto lhs_value = read_array(lhs_data, lhs_data_index); + const auto lhs_quant_index = (i_m / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; + const auto lhs_value = read_array(k_chunk_ptr, lhs_data_index); const auto lhs_scale = lhs_scales != nullptr ? read_array(lhs_scales, lhs_quant_index) : static_cast(1); const auto lhs_zero_point = lhs_zero_points != nullptr ? read_array(lhs_zero_points, lhs_quant_index) : static_cast(0); - const auto rhs_data_index = x * (k_chunk_count * k_chunk_length) + i; - const auto rhs_quant_index = (x / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; + const auto rhs_data_index = i_n * (k_chunk_count * k_chunk_length) + i; + const auto rhs_quant_index = (i_n / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; const auto rhs_value = read_array(rhs_data, rhs_data_index); const auto rhs_scale = rhs_scales != nullptr ? read_array(rhs_scales, rhs_quant_index) : static_cast(1); @@ -235,19 +241,19 @@ std::vector indirect_matmul_nt_t_quantized( } if (bias_data != nullptr) { - const auto bias_value = read_array(bias_data, x); + const auto bias_value = read_array(bias_data, i_n); const auto bias_scale = bias_scales != nullptr - ? read_array(bias_scales, x / bias_quant_width) + ? read_array(bias_scales, i_n / bias_quant_width) : static_cast(1); const auto bias_zero_point = bias_zero_points != nullptr - ? read_array(bias_zero_points, x / bias_quant_width) + ? read_array(bias_zero_points, i_n / bias_quant_width) : static_cast(0); acc += (static_cast(bias_value) - static_cast(bias_zero_point)) * static_cast(bias_scale); } - write_array(dst.data(), y * n + x, acc); + write_array(dst.data(), i_m * n + i_n, acc); } } @@ -331,7 +337,8 @@ matmul_nt_t_quantized indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // - const void* const* lhs_ptrs, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* lhs_zero_points, size_t lhs_quant_height, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, size_t rhs_quant_width, // diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 913de1f8..8d83e98c 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -122,7 +122,12 @@ std::vector matmul_clamp_nt_t( /// @param[in] m The LHS and output height. /// @param[in] n The RHS height and output width. /// @param[in] k The LHS and RHS width. +/// @param[in] k_chunk_count Number of K chunk pointers per row in lhs_idata matrix +/// @param[in] k_chunk_length Lenght of each K chunk pointed to in lhs_idata matrix /// @param[in] lhs_data The LHS data matrix. +/// @param[in] lhs_idata The indirect LHS data matrix. +/// @param[in] lhs_offset The indirection LHS data matrix offset, applied to non-padding pointers +/// @parma[in] lhs_padding The indirection LHS padding chunk pointer /// @param[in] lhs_scales The LHS quantization scales matrix. /// @param[in] lhs_zero_points The LHS quantization zero points matrix. /// @param[in] lhs_quant_width The LHS quantization block width. @@ -161,7 +166,8 @@ template < typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> std::vector indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // - const void* const* lhs_ptrs, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* lhs_zero_points, size_t lhs_quant_height, size_t lhs_quant_width, // const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, size_t rhs_quant_width, // diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index d71193a9..e7ded92c 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include "kai/kai_common.h" @@ -336,6 +337,8 @@ struct TestReference { Buffer lhs_qai8_zero_points; IndirectionBuffer lhs_qai8_indirect; Buffer lhs_qai8_indirect_packed; + Buffer lhs_qai8_indirect_padding; + size_t lhs_qai8_indirect_offset; Buffer rhs_qsi8; Buffer rhs_scales; @@ -380,30 +383,59 @@ static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel .run_matmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, }; -// M, N, K, k_chunk_length, pack.m, pack.n, pack.k -using TestDataId = std::tuple; +static constexpr int8_t padding_value = 0; + +// Functionality for hashing generated test data. +// This is particularly useful for portion testing +// which reuses the exact same data for all portions +struct TestDataId { + MatMulShape shape; + MatMulShape shape_pack; + size_t chunk_len; + bool pad_testing; + +private: + friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { + return // + lhs.shape == rhs.shape && // + lhs.shape_pack == rhs.shape_pack && // + lhs.chunk_len == rhs.chunk_len && // + lhs.pad_testing == rhs.pad_testing; + } +}; + +struct HashTestDataId { + size_t operator()(const TestDataId& id) const { + return // + (HashMatMulShape{}(id.shape) << 0) ^ // + (HashMatMulShape{}(id.shape_pack) << 1) ^ // + (std::hash{}(id.chunk_len) << 2) ^ // + (std::hash{}(id.pad_testing) << 3); + } +}; + // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) -static std::map g_data; +static std::unordered_map g_data; // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) /// Generate test reference data -static const TestReference& get_test_reference( - const MatMulShape& shape, const MatMulShape& pack_shape, size_t k_chunk_len) { +static const TestReference& get_test_reference(const TestDataId& test_data_id) { // ============================================================ // Generates input and reference output data // ============================================================ // Attempt to find test data in cache - const TestDataId data_id{shape.m, shape.n, shape.k, k_chunk_len, pack_shape.m, pack_shape.n, pack_shape.k}; - const auto data_it = g_data.find(data_id); + const auto data_it = g_data.find(test_data_id); if (data_it != g_data.end()) { return data_it->second; } + const auto& [shape, pack_shape, k_chunk_len, pad_testing] = test_data_id; + // Generates the input data in floating-point. - const auto lhs_f32 = fill_random(shape.m * shape.k, seed); - const auto rhs_f32 = fill_random(shape.k * shape.n, seed); - const auto bias_f32 = fill_random(shape.n, seed); + Buffer lhs_f32 = fill_random(shape.m * shape.k, seed); + const Buffer rhs_f32 = fill_random(shape.k * shape.n, seed); + const Buffer bias_f32 = fill_random(shape.n, seed); // Quantizes the input data. // * LHS: 8-bit asymmetric per-matrix quantization. @@ -417,18 +449,26 @@ static const TestReference& get_test_reference( const auto lhs_scale = read_array(lhs_qai8_scales.data(), 0); const auto lhs_zero_point = read_array(lhs_qai8_zero_points.data(), 0); - IndirectionBuffer lhs_qai8_indirect; - const size_t k_chunk_count = shape.k / k_chunk_len; assert(k_chunk_count * k_chunk_len == shape.k); // Setup an indirection buffer, where each "row" contains `k_chunk_count` // pointers to chunks of length `k_chunk_len` in the input_buffer + IndirectionBuffer lhs_qai8_indirect(shape.m * k_chunk_count); + Buffer lhs_padding(k_chunk_len, padding_value); for (size_t m_i = 0; m_i < shape.m; ++m_i) { for (size_t k_chunk_idx = 0; k_chunk_idx < k_chunk_count; ++k_chunk_idx) { - lhs_qai8_indirect.push_back(&lhs_qai8.at(m_i * shape.k + k_chunk_idx * k_chunk_len)); + const size_t idx = m_i * k_chunk_count + k_chunk_idx; + if (pad_testing and m_i == 0) { + // Push padding pointers for first row + lhs_qai8_indirect[idx] = lhs_padding.data(); + } else { + uintptr_t offset = m_i * shape.k + k_chunk_idx * k_chunk_len; + lhs_qai8_indirect[idx] = reinterpret_cast(offset); + } } } + const auto indirection_base = reinterpret_cast(lhs_qai8.data()); // Reorder indirection pointers to layout the packing kernel expectes Buffer lhs_qai8_indirect_packed = reorder_block( @@ -449,11 +489,12 @@ static const TestReference& get_test_reference( quantize_symmetric_per_block(bias_f32.data(), bias_scales.data(), shape.n, 1, 1); // Runs the reference implementation of matmul to produce floating-point result. + const void* const* lhs_iptr = reinterpret_cast(lhs_qai8_indirect.data()); const auto ref_dst_f32 = indirect_matmul_nt_t_quantized( - shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape - reinterpret_cast(lhs_qai8_indirect.data()), &lhs_scale, - &lhs_zero_point, // LHS, scaling factor and zero point + shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape + lhs_iptr, indirection_base, lhs_padding.data(), // LHS indirection, offset and padding + &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point shape.m, shape.k, // LHS quantization window shape rhs_qsi8_t.data(), rhs_scales.data(), nullptr, // RHS scaling factors 1, shape.k, // RHS quantization window shape @@ -505,7 +546,7 @@ static const TestReference& get_test_reference( rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, pack_shape.n, pack_shape.k); - const TestReference& reference = g_data[data_id] = { + const TestReference& reference = g_data[test_data_id] = { .clamp = {.min = dst_qai8_clamp_min, .max = dst_qai8_clamp_max}, .qa_lhs = {.scale = lhs_scale, .zero_point = lhs_zero_point}, @@ -516,6 +557,8 @@ static const TestReference& get_test_reference( .lhs_qai8_zero_points = std::move(lhs_qai8_zero_points), .lhs_qai8_indirect = std::move(lhs_qai8_indirect), .lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed), + .lhs_qai8_indirect_padding = std::move(lhs_padding), + .lhs_qai8_indirect_offset = indirection_base, .rhs_qsi8 = std::move(rhs_qsi8), .rhs_scales = std::move(rhs_scales), @@ -715,7 +758,8 @@ TEST_P(MatMulQuantizedTest, EndToEnd) { GTEST_SKIP() << "CPU features are not supported by current CPU"; } - TestReference reference = get_test_reference(shape, variant.acc_pack, 1); + TestDataId test_data_id{shape, variant.acc_pack, shape.k, false}; + const TestReference& reference = get_test_reference(test_data_id); // Check scheduling parameters const auto imp_mr = variant.matmul.get_mr(); @@ -762,13 +806,11 @@ static Buffer lhs_pack( const size_t input_offset = portion.start_row() * k_chunk.count; const size_t dst_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); - // TODO: `lhs_offset` is currently not being excercized! - // TODO: Ensure that `zero` pointers are tested variant.pack( portion.height(), k_chunk.count, k_chunk.length, // Dimensions indirection_pointer + input_offset, // Indirection input - 0, // chunk offset - nullptr, // padding pointer + reference.lhs_qai8_indirect_offset, // chunk offset + reference.lhs_qai8_indirect_padding.data(), // padding pointer packed.data() + dst_offset); return packed; @@ -854,7 +896,9 @@ TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { GTEST_SKIP() << "CPU features are not supported by current CPU"; } - const TestReference& reference = get_test_reference(shape, variant.acc_pack, k_chunk.length); + // Toggle padding testst when LHS has more than one row + TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1}; + const TestReference& reference = get_test_reference(test_data_id); const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); Buffer packed_lhs = imatmul::lhs_pack(variant.lhs_pack, portion, reference, shape.m, k_chunk); -- GitLab From d501e2874f884da3102a3ddd770b1ee2437d8464 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Tue, 8 Apr 2025 08:15:45 +0000 Subject: [PATCH 12/14] Add extra clamp testing Extend the QAI8 testing suite by iterating over clamp rates that will clamp output range to no clamping, clamp 10% of range, and clamp 50% of range Signed-off-by: Emil Ohlsson Reviewed-by: Emil Ohlsson Approved-by: Felix Johnny Thomasmathibalan --- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 68 +++++++++++-------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index e7ded92c..77c697c0 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -305,8 +305,7 @@ const std::array gemv_variants = { }, }; -constexpr uint32_t seed = 0; ///< Random seed used for tests -constexpr float output_clamp_rate = 0.1F; ///< Clamping range in ration of output +constexpr uint32_t seed = 0; ///< Random seed used for tests /// Value range template @@ -393,14 +392,16 @@ struct TestDataId { MatMulShape shape_pack; size_t chunk_len; bool pad_testing; + float clamp_ratio; private: friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { - return // - lhs.shape == rhs.shape && // - lhs.shape_pack == rhs.shape_pack && // - lhs.chunk_len == rhs.chunk_len && // - lhs.pad_testing == rhs.pad_testing; + return // + lhs.shape == rhs.shape && // + lhs.shape_pack == rhs.shape_pack && // + lhs.chunk_len == rhs.chunk_len && // + lhs.pad_testing == rhs.pad_testing && // + lhs.clamp_ratio == rhs.clamp_ratio; } }; @@ -410,7 +411,8 @@ struct HashTestDataId { (HashMatMulShape{}(id.shape) << 0) ^ // (HashMatMulShape{}(id.shape_pack) << 1) ^ // (std::hash{}(id.chunk_len) << 2) ^ // - (std::hash{}(id.pad_testing) << 3); + (std::hash{}(id.pad_testing) << 3) ^ // + (std::hash{}(id.clamp_ratio) << 4); } }; @@ -430,7 +432,7 @@ static const TestReference& get_test_reference(const TestDataId& test_data_id) { return data_it->second; } - const auto& [shape, pack_shape, k_chunk_len, pad_testing] = test_data_id; + const auto& [shape, pack_shape, k_chunk_len, pad_testing, clamp_ratio] = test_data_id; // Generates the input data in floating-point. Buffer lhs_f32 = fill_random(shape.m * shape.k, seed); @@ -520,8 +522,8 @@ static const TestReference& get_test_reference(const TestDataId& test_data_id) { const auto ref_dst_f32_max = reduce_max(ref_dst_f32.data(), shape.m * shape.n); const auto ref_dst_f32_range = ref_dst_f32_max - ref_dst_f32_min; - const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * output_clamp_rate / 2; - const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * output_clamp_rate / 2; + const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * clamp_ratio / 2; + const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * clamp_ratio / 2; const auto dst_qai8_clamp_min = quantize_asymmetric(ref_dst_f32_clamp_min, dst_scale, dst_zero_point); const auto dst_qai8_clamp_max = @@ -718,28 +720,29 @@ static void test_matmul( compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped); } -using MatMulQuantizedTest = testing::TestWithParam>; +using MatMulQuantizedTest = testing::TestWithParam>; using IndirectMatMulQuantizedTest = - testing::TestWithParam>; + testing::TestWithParam>; static std::string test_description( const MatMulVariant& variant, // const MatMulShape& shape, // - const MatrixPortion& portion) { + const MatrixPortion& portion, float clamp_ratio) { std::stringstream sstream; sstream << "Method_" << variant.name << "__M_" // << shape.m << "__N_" << shape.n << "__K_" << shape.k // << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // << "__PortionHeight_" << static_cast(portion.height() * 1000) // - << "__PortionWidth_" << static_cast(portion.width() * 1000); + << "__PortionWidth_" << static_cast(portion.width() * 1000) // + << "__clamp_ratio_" << static_cast(clamp_ratio * 100); return sstream.str(); }; static std::string test_description( const IndirectMatMulVariant& variant, // const MatMulShape& shape, // - const MatrixPortion& portion, size_t k_chunk_len) { + const MatrixPortion& portion, size_t k_chunk_len, float clamp_ratio) { std::stringstream sstream; sstream << "Method_" << variant.name << "__M_" // << shape.m << "__N_" << shape.n << "__k_chunk_count_" << shape.k // @@ -747,18 +750,19 @@ static std::string test_description( << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // << "__PortionHeight_" << static_cast(portion.height() * 1000) // << "__PortionWidth_" << static_cast(portion.width() * 1000) // - << "__k_chunk_len_" << k_chunk_len; + << "__k_chunk_len_" << k_chunk_len // + << "__clamp_ratio_" << static_cast(clamp_ratio * 100); return sstream.str(); }; TEST_P(MatMulQuantizedTest, EndToEnd) { - const auto& [variant, shape, output_portion] = GetParam(); + const auto& [variant, shape, output_portion, clamp_ratio] = GetParam(); if (!variant.is_supported()) { GTEST_SKIP() << "CPU features are not supported by current CPU"; } - TestDataId test_data_id{shape, variant.acc_pack, shape.k, false}; + TestDataId test_data_id{shape, variant.acc_pack, shape.k, false, clamp_ratio}; const TestReference& reference = get_test_reference(test_data_id); // Check scheduling parameters @@ -878,8 +882,6 @@ static Buffer matmul( dst.data() + dst_offset, // DST shape.n * sizeof(uint8_t), &requantization); - // TODO: Ensure `clamp` is tested - return dst; } } // namespace imatmul @@ -888,7 +890,7 @@ TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { /* This is a bit special, as shape.k must be k_chunk_len * k_chunk_count * so instead of inventing a new special kind of shape, simply multiply * with `k_chunk_len` here */ - const auto& [variant, shape_k_chunk, output_portion, k_chunk_len] = GetParam(); + const auto& [variant, shape_k_chunk, output_portion, k_chunk_len, clamp_ratio] = GetParam(); const KChunk k_chunk{shape_k_chunk.k, k_chunk_len}; MatMulShape shape{shape_k_chunk.m, shape_k_chunk.n, k_chunk.count * k_chunk.length}; @@ -897,7 +899,7 @@ TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { } // Toggle padding testst when LHS has more than one row - TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1}; + TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1, clamp_ratio}; const TestReference& reference = get_test_reference(test_data_id); const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); @@ -959,12 +961,14 @@ INSTANTIATE_TEST_SUITE_P( MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner. MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. // clang-format on - })), + }), + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), [](const auto& info) -> std::string { return test_description( std::get(info.param), // std::get(info.param), // - std::get(info.param)); + std::get(info.param), // + std::get(info.param)); }); INSTANTIATE_TEST_SUITE_P( @@ -996,12 +1000,15 @@ INSTANTIATE_TEST_SUITE_P( MatrixPortion(0, 0, 1, .5), // Left half MatrixPortion(0, .25, 1, .5) // Middle half // clang-format on - })), + }), + // Clamp range + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), [](const auto& info) -> std::string { return test_description( std::get(info.param), // std::get(info.param), // - std::get(info.param)); + std::get(info.param), // + std::get(info.param)); }); INSTANTIATE_TEST_SUITE_P( @@ -1020,12 +1027,15 @@ INSTANTIATE_TEST_SUITE_P( // clang-format on }), // k_chunk_len - testing::ValuesIn(std::initializer_list{1, 2, 3, 4, 8, 11, 32})), + testing::ValuesIn(std::initializer_list{1, 2, 3, 4, 8, 11, 32}), + // Clamp range + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), [](const auto& info) -> std::string { return test_description( std::get(info.param), // std::get(info.param), // std::get(info.param), // - std::get(info.param)); + std::get(info.param), // + std::get(info.param)); }); } // namespace kai::test -- GitLab From 517502011c7f7647d1cfd958e50b4db28df2de95 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 9 Apr 2025 09:20:41 +0000 Subject: [PATCH 13/14] Address review comments from feature review * Rename `matmul` in imatmul interface to `imatmul` * rename `zero` argument in lhs pack to `pad_ptr` * Clarify `k_chunk_length` to mean "in bytes" Signed-off-by: Emil Ohlsson Approved-by: Jakub Sujak --- ...matmul_clamp_qai8_qai8p_qsi8cxp_interface.h | 4 ++-- .../kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c | 4 ++-- .../kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h | 18 +++++++++++------- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 8 ++++---- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h index 0a1e23f9..01b5ca95 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -29,7 +29,7 @@ typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t)( typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); /// Micro-kernel core function ("run" method) -typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t)( +typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); @@ -44,7 +44,7 @@ struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t get_dst_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t get_dst_size; - kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t run_matmul; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t run_imatmul; }; #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c index 3e682437..25a48afc 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c @@ -45,7 +45,7 @@ size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme( void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, - const void* zero, void* lhs_packed) { + const void* pad_ptr, void* lhs_packed) { KAI_ASSUME(lhs_ptrs != NULL); KAI_ASSUME(lhs_packed != NULL); @@ -64,7 +64,7 @@ void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( for (size_t y = 0; y < height; y += 1) { KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); in[y] = *(lhs_ptrs + i_m * k_chunk_count + i_k_chunk * m_step + y); - if (in[y] != zero) { + if (in[y] != pad_ptr) { in[y] += lhs_ptr_offset; } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h index 721c07a8..7136d837 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h @@ -23,7 +23,7 @@ size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void); /// /// @param[in] m_idx Row index in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( @@ -33,7 +33,7 @@ size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( /// /// @param[in] m Number of rows in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. /// /// @return The size in bytes of the packed LHS buffer. size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); @@ -42,14 +42,18 @@ size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_ /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split. -/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of `m * k_chunk_count` pointers. -/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs array, excluding zero pointers. -/// @param[in] zero Pointer to a zero element. Used to check for padding pointers in @ref lhs_ptrs. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of +/// t `m * k_chunk_count` pointers. +/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs +/// array, excluding zero pointers. +/// @param[in] pad_ptr Pointer to chunk used for padding. @ref lhs_ptr_offset is +/// not applied to this pointer when used in @ref lhs_ptrs. This can +/// be NULL if there is no padding used @ref lhs_ptrs /// @param[out] lhs_packed Packed LHS matrix. void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, - const void* zero, void* lhs_packed); + const void* pad_ptr, void* lhs_packed); #ifdef __cplusplus } // extern "C" diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 77c697c0..ff01ab4a 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -133,7 +133,7 @@ struct MatMulIndirectKernel { std::function - matmul; + imatmul; }; const static RhsPackKernel rhs_pack = { @@ -261,7 +261,7 @@ const std::array indirect_gemm_variants = { kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .matmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, }, }, }; @@ -379,7 +379,7 @@ static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .run_matmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, }; static constexpr int8_t padding_value = 0; @@ -875,7 +875,7 @@ static Buffer matmul( }; // Call matmul kernel - variant.matmul( + variant.imatmul( portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions packed_lhs.data() + lhs_offset, // LHS packed_rhs.data() + rhs_offset, // RHS -- GitLab From 90c21e242d789542cf725c884b6d4f11f7376a47 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 9 Apr 2025 10:42:13 +0000 Subject: [PATCH 14/14] Remove imatmul MR, NR, KR, and change mismatch dump Signed-off-by: Emil Ohlsson Approved-by: Jakub Sujak --- ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c | 12 ---------- ...8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h | 21 ---------------- ...atmul_clamp_qai8_qai8p_qsi8cxp_interface.h | 6 ----- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 24 ++++--------------- 4 files changed, 5 insertions(+), 58 deletions(-) diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c index d53b1f0c..b2eeb5ef 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -30,18 +30,6 @@ size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_ return kai_nr * kai_get_sme_vector_length_u32(); } -size_t kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); -} - -size_t kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); -} - -size_t kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { - return kai_kr; -} - size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h index cc501580..21c7b526 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -32,27 +32,6 @@ size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_ /// @return The n step value. size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); -/// Gets mr value. -/// -/// This is the packing parameter which must be used to pack the LHS matrix. -/// -/// @return The mr value. -size_t kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); - -/// Gets nr value. -/// -/// This is the packing parameter which must be used to pack the RHS matrix. -/// -/// @return The nr value. -size_t kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); - -/// Gets kr value. -/// -/// This is the packing parameter which must be used to pack the LHS and RHS matrix. -/// -/// @return The kr value. -size_t kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); - /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// /// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h index 01b5ca95..84ca66b1 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -17,9 +17,6 @@ extern "C" { /// Micro-kernel helper functions ("get" methods) typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_m_step_func_t)(void); typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_n_step_func_t)(void); -typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_mr_func_t)(void); -typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_nr_func_t)(void); -typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_kr_func_t)(void); typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func_t)( size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t)( @@ -37,9 +34,6 @@ typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t)( struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_m_step_func_t get_m_step; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_n_step_func_t get_n_step; - kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_mr_func_t get_mr; - kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_nr_func_t get_nr; - kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_kr_func_t get_kr; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t get_dst_offset; diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index ff01ab4a..d1529ea0 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -123,9 +123,6 @@ struct MatMulKernel { struct MatMulIndirectKernel { std::function get_m_step; std::function get_n_step; - std::function get_mr; - std::function get_nr; - std::function get_kr; std::function get_packed_lhs_offset; std::function get_packed_rhs_offset; std::function get_dst_offset; @@ -252,9 +249,6 @@ const std::array indirect_gemm_variants = { MatMulIndirectKernel{ .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_mr = kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_nr = kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_kr = kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_packed_lhs_offset = kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_packed_rhs_offset = @@ -370,9 +364,6 @@ static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface [[maybe_unused]] = { .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_mr = kai_get_mr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_nr = kai_get_nr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_kr = kai_get_kr_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_rhs_packed_offset = @@ -654,7 +645,7 @@ static void compare_matmul_result( const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) { size_t mismatches = 0; bool printed_row = false; - bool printed_mismatch = false; + std::ostringstream sstream; for (size_t m_i = 0; m_i < shape.m; ++m_i) { for (size_t n_i = 0; n_i < shape.n; ++n_i) { const auto i = m_i * shape.n + n_i; @@ -667,25 +658,20 @@ static void compare_matmul_result( const auto threshold = in_area ? 1 : 0; const bool mismatch = error > threshold; if (mismatch) { - if (not printed_mismatch) { - std::cout << "Mismatch(es) found:\n"; - printed_mismatch = true; - } if (not printed_row) { - std::cout << " row=" << m_i; - std::cout << " "; + sstream << " row=" << m_i << ", columns: "; printed_row = true; } - std::cout << n_i << ", "; + sstream << n_i << ", "; } mismatches += static_cast(mismatch); } if (printed_row) { - std::cout << "\n"; + sstream << "\n"; } printed_row = false; } - ASSERT_EQ(mismatches, 0) << "There are mismatches between reference result actual result"; + ASSERT_EQ(mismatches, 0) << "Mismatches between reference result and actual result:\n" << sstream.str(); } /// Test MatMul of GEMM/GEMV like kernel -- GitLab