diff --git a/CHANGELOG.md b/CHANGELOG.md index 1247ef29981e6b313615168eca23bec690e20b09..e5ce678a9f4ff75a371f0ed8a9f516bf7a717cef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,13 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- New SME micro-kernels: + - Indirect matrix multiplication (MxN) of QAI8 input and output. + - Packing kernels for LHS and RHS +- New SME2 micro-kernels: + - Indirect matrix multiplication (MxN) of QAI8 input and output. + - Matrix multiplication of packed indirect LHS and packed RHS + ## v1.6.0 - Add CMake installation and `find_package()` support. diff --git a/CMakeLists.txt b/CMakeLists.txt index 53ce829f60748d381793219e160eaae6fcf4af00..b6ccd8306cee5dd5912d507e2a9ab98901396e01 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,9 +181,11 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME + kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c + kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -193,6 +195,7 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2 + kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c diff --git a/kai/kai_common.h 
b/kai/kai_common.h index c1cb1eca10d750308eb09945d4d71301a2977bb5..47e00bfa67a67026eebef4e78603164df63b7e6c 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -54,6 +54,9 @@ extern "C" { #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) +/// Largest supported SME vector length in bytes +#define KAI_SME_VEC_LENGTH_MAX_BYTES 256 // NOLINT(cppcoreguidelines-macro-to-enum,modernize-macro-to-enum) + /// Gets the version of the project in the Major.Minor.Patch semantic versioning format. /// /// @return Project version as a string literal. diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 3796d88d150d8ee230cdee2421b192837c51f31f..aa257daa96dc4635ea6f02ecae2eb359175a0d41 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -114,10 +114,12 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ + "pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_lhs_pack_f32p2vlx1_f32_sme", "pack/kai_lhs_pack_x16p2vlx2_x16_sme", "pack/kai_lhs_pack_x8p2vlx4_x8_sme", + "pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme", "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", @@ -129,6 +131,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS = [ + "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c 
b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c new file mode 100644 index 0000000000000000000000000000000000000000..b2eeb5efcc17aeb745cc6e71c9e5104b70607bfe --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -0,0 +1,400 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 4; + +size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + const size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + return 
kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() * + (sizeof(int32_t) + indirect_k * sizeof(int8_t) + sizeof(float)); +} + +size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + const size_t block_idx = n_idx / kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(); + return block_idx * + kai_get_rhs_packed_stride_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + k_chunk_count, k_chunk_length); +} + +size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0); + + return m_idx * dst_row_stride + n_idx * sizeof(int8_t); +} + +size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n) { + return m * n * sizeof(int8_t); +} + +void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params) { + typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + int32_t min; + int32_t max; + int32_t result_zero_point; + const int n_0; + void* accumulator_buffer; + uint64_t flags; + } KernelArgs; + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + + size_t indirect_k = k_chunk_count * kai_roundup(k_chunk_length, kai_kr); + + args.C = dst; + args.ldcb = 
dst_row_stride; + args.M = m; + args.N = n; + args.K = indirect_k; + args.min = params->min_value; + args.max = params->max_value; + args.result_zero_point = params->output_zero_point; + + args.accumulator_buffer = NULL; + args.flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ldr w14, [%x[args], %[offsetof_M]]\n" + "mov x13, #0x0\n" + "mov x11, #0x0\n" + "ptrue p1.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr w10, [%x[args], %[offsetof_N]]\n" + "ldr x9, [%x[args], %[offsetof_A]]\n" + "1:" // M loop + "ldr x28, [%x[args], %[offsetof_B]]\n" + "2:" // N loop + ".inst 0x25aa4570 // whilelt pn8.s, x11, x10, VLx2\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "mov x27, x9\n" + ".inst 0xa040438e // ld1w { z14.s-z15.s }, p8/Z, [x28]\n" // Load bias + "addvl x28, x28, #2\n" + ".inst 0xc09025c0 // addha za0.s, p1/M, p1/M, z14.s\n" + ".inst 0xc09025e1 // addha za1.s, p1/M, p1/M, z15.s\n" + ".inst 0xc09025c2 // addha za2.s, p1/M, p1/M, z14.s\n" + ".inst 0xc09025e3 // addha za3.s, p1/M, p1/M, z15.s\n" + "ldr x20, [%x[args], %[offsetof_K]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 6f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" + ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" + ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "ble 5f\n" + "4:" // K loop + ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" + "subs x21, 
x21, #0x1\n" + ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" + ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" + ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" + ".inst 0xa0400762 // ld1b { z2.b-z3.b }, pn9.b/Z, [x27]\n" + ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" + ".inst 0xa1400780 // ld1b { z0.b, z8.b }, pn9.b/Z, [x28]\n" + ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, z21.b\n" + ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" + ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" + ".inst 0xa0410772 // ld1b { z18.b-z19.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" + ".inst 0xa0410794 // ld1b { z20.b-z21.b }, pn9.b/Z, [x28, #0x2, MUL VL]\n" + ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" + ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" + ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" + ".inst 0xa042077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa0420796 // ld1b { z22.b-z23.b }, pn9.b/Z, [x28, #0x4, MUL VL]\n" + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" + ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" + ".inst 0xa0430778 // ld1b { z24.b-z25.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa0430784 // ld1b { z4.b-z5.b }, pn9.b/Z, [x28, #0x6, MUL VL]\n" + "addvl x28, x28, #8\n" + "bgt 4b\n" + "5:" // K loop tail + ".inst 0xa0802440 // smopa za0.s, p1/M, p1/M, z2.b, z0.b\n" + ".inst 0xa0882441 // smopa za1.s, p1/M, p1/M, z2.b, z8.b\n" + ".inst 0xa0802462 // smopa za2.s, p1/M, p1/M, z3.b, z0.b\n" + ".inst 0xa0882463 // smopa za3.s, p1/M, p1/M, z3.b, z8.b\n" + ".inst 0xa0942640 // smopa za0.s, p1/M, p1/M, z18.b, z20.b\n" + ".inst 0xa0952641 // smopa za1.s, p1/M, p1/M, z18.b, 
z21.b\n" + ".inst 0xa0942662 // smopa za2.s, p1/M, p1/M, z19.b, z20.b\n" + ".inst 0xa0952663 // smopa za3.s, p1/M, p1/M, z19.b, z21.b\n" + ".inst 0xa0962740 // smopa za0.s, p1/M, p1/M, z26.b, z22.b\n" + ".inst 0xa0972741 // smopa za1.s, p1/M, p1/M, z26.b, z23.b\n" + ".inst 0xa0962762 // smopa za2.s, p1/M, p1/M, z27.b, z22.b\n" + ".inst 0xa0972763 // smopa za3.s, p1/M, p1/M, z27.b, z23.b\n" + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0852701 // smopa za1.s, p1/M, p1/M, z24.b, z5.b\n" + ".inst 0xa0842722 // smopa za2.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0852723 // smopa za3.s, p1/M, p1/M, z25.b, z5.b\n" + "6:" // K oddments + "cbz x20, 8f\n" + "7:" // K oddments: Loop + ".inst 0xa0400770 // ld1b { z16.b-z17.b }, pn9.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400788 // ld1b { z8.b-z9.b }, pn9.b/Z, [x28]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0882600 // smopa za0.s, p1/M, p1/M, z16.b, z8.b\n" + ".inst 0xa0892601 // smopa za1.s, p1/M, p1/M, z16.b, z9.b\n" + ".inst 0xa0882622 // smopa za2.s, p1/M, p1/M, z17.b, z8.b\n" + ".inst 0xa0892623 // smopa za3.s, p1/M, p1/M, z17.b, z9.b\n" + "bgt 7b\n" + "8:" // K oddments: End + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x14, x13\n" + "cntw x24\n" + "ld1rw { z27.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "ldr x23, [%x[args], %[offsetof_ldcb]]\n" + "whilelt p0.h, x11, x10\n" + "cmp x25, x24\n" + "ld1rw { z1.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x24, LT\n" + "ld1rw { z0.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_result_zero_point]]\n" + "mov x12, #0x0\n" + "add x26, x26, x11\n" // C += n + "lsr x21, x22, #0x2\n" + "ld1w { z22.s }, p1/Z, [x28]\n" + "madd x26, x13, x23, x26\n" // C += m * ldc + "ld1w { z26.s }, p1/Z, [x28, #1, MUL VL]\n" + "and x20, x22, #0x3\n" + "addvl x28, x28, #2\n" + "cbz x21, 11f\n" + "10:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860410 // mova { z16.s-z19.s }, 
za0h.s[x12]\n" + ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmul z16.s, z16.s, z22.s\n" + "fmul z17.s, z17.s, z22.s\n" + "add x12, x12, #0x4\n" + "fmul z18.s, z18.s, z22.s\n" + "fmul z19.s, z19.s, z22.s\n" + "cmp x12, x21, LSL #2\n" + "fmul z28.s, z28.s, z26.s\n" + "fmul z29.s, z29.s, z26.s\n" + "fmul z30.s, z30.s, z26.s\n" + "fmul z31.s, z31.s, z26.s\n" + ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" + ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" + ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf7c // sclamp { z28.s-z31.s }, z27.s, z1.s\n" + "uzp1 z5.h, z16.h, z28.h\n" + "uzp1 z20.h, z17.h, z29.h\n" + "uzp1 z17.h, z18.h, z30.h\n" + "uzp1 z16.h, z19.h, z31.h\n" + "st1b { z5.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z20.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z17.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "blt 10b\n" + "11:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 12f\n" + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmul z4.s, z4.s, z22.s\n" + "fmul z5.s, z5.s, z22.s\n" + "subs x20, x20, #0x1\n" + "fmul z6.s, z6.s, z22.s\n" + "fmul z7.s, z7.s, z22.s\n" + "fmul z12.s, z12.s, z26.s\n" + "fmul z13.s, z13.s, z26.s\n" + "fmul z14.s, z14.s, z26.s\n" + "fmul z15.s, z15.s, z26.s\n" + ".inst 
0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" + ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" + "uzp1 z16.h, z4.h, z12.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 12f\n" + "subs x20, x20, #0x1\n" + "uzp1 z16.h, z5.h, z13.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 12f\n" + "uzp1 z16.h, z6.h, z14.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "12:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 16f\n" + "cmp x25, x24\n" + "mov x12, #0x0\n" + "csel x20, x25, x24, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 14f\n" + "13:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmul z8.s, z8.s, z22.s\n" + "fmul z9.s, z9.s, z22.s\n" + "add x12, x12, #0x4\n" + "fmul z10.s, z10.s, z22.s\n" + "fmul z11.s, z11.s, z22.s\n" + "cmp x12, x21, LSL #2\n" + "fmul z16.s, z16.s, z26.s\n" + "fmul z17.s, z17.s, z26.s\n" + "fmul z18.s, z18.s, z26.s\n" + "fmul z19.s, z19.s, z26.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1b8e210 // frintn { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" + ".inst 0xc131e210 // fcvtzs { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 
0xc1a0ab10 // add { z16.s-z19.s }, { z16.s-z19.s }, z0.s\n" + ".inst 0xc1a1cf68 // sclamp { z8.s-z11.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf70 // sclamp { z16.s-z19.s }, z27.s, z1.s\n" + "uzp1 z21.h, z8.h, z16.h\n" + "uzp1 z20.h, z9.h, z17.h\n" + "uzp1 z17.h, z10.h, z18.h\n" + "uzp1 z16.h, z11.h, z19.h\n" + "st1b { z21.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z20.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z17.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "blt 13b\n" + "14:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 15f\n" + ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" + ".inst 0xc0860464 // mova { z4.s-z7.s }, za3h.s[x12]\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + "fmul z12.s, z12.s, z22.s\n" + "fmul z13.s, z13.s, z22.s\n" + "subs x20, x20, #0x1\n" + "fmul z14.s, z14.s, z22.s\n" + "fmul z15.s, z15.s, z22.s\n" + "fmul z4.s, z4.s, z26.s\n" + "fmul z5.s, z5.s, z26.s\n" + "fmul z6.s, z6.s, z26.s\n" + "fmul z7.s, z7.s, z26.s\n" + ".inst 0xc1b8e18c // frintn { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc131e18c // fcvtzs { z12.s-z15.s }, { z12.s-z15.s }\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab0c // add { z12.s-z15.s }, { z12.s-z15.s }, z0.s\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc1a1cf6c // sclamp { z12.s-z15.s }, z27.s, z1.s\n" + ".inst 0xc1a1cf64 // sclamp { z4.s-z7.s }, z27.s, z1.s\n" + "uzp1 z16.h, z12.h, z4.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + "subs x20, x20, #0x1\n" + "uzp1 z16.h, z13.h, z5.h\n" + "st1b { z16.h }, p0, [x26]\n" + "add x26, x26, x23\n" + "beq 15f\n" + "uzp1 z16.h, z14.h, z6.h\n" + "st1b { z16.h }, p0, [x26]\n" + "15:" // Store to output array: Accumulator row 1 oddments: End 
+ "16:" // Store to output array: End + "incw x11, ALL, MUL #2\n" + "cmp x11, x10\n" + "blt 2b\n" + "incw x13, ALL, MUL #2\n" + "mov x11, #0x0\n" + "cmp x13, x14\n" + "mov x9, x27\n" + "blt 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r"(&args), [offsetof_A] "I"(offsetof(KernelArgs, A)), [offsetof_B] "I"(offsetof(KernelArgs, B)), + [offsetof_C] "I"(offsetof(KernelArgs, C)), [offsetof_K] "I"(offsetof(KernelArgs, K)), + [offsetof_KernelArgs_max] "I"(offsetof(KernelArgs, max)), + [offsetof_KernelArgs_min] "I"(offsetof(KernelArgs, min)), + [offsetof_KernelArgs_result_zero_point] "I"(offsetof(KernelArgs, result_zero_point)), + [offsetof_M] "I"(offsetof(KernelArgs, M)), [offsetof_N] "I"(offsetof(KernelArgs, N)), + [offsetof_ldcb] "I"(offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", + "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", + "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", + "z9"); +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h new file mode 100644 index 0000000000000000000000000000000000000000..21c7b52681a113a73a57e8350d212162e1fbd9e1 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h @@ -0,0 +1,100 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// -# kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme to pack the LHS matrix. +/// -# kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// The packed row length is `k_chunk_count * kai_roundup(k_chunk_length, kr)` bytes. +/// +/// @return The offset in bytes to the data element. 
+size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_row_stride Distance between start of two rows in the output buffer. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_row_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffer (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. 
+/// * Output: @ref kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_row_stride Row stride in bytes of the output matrix. +/// +/// @param[in] params Requantization and clamp parameters. + +void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..84ca66b1bd14a73af0e5b263592267e9dac87ce8 --- /dev/null +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -0,0 +1,46 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: imatmul_clamp_qai8_qai8p_qsi8cxp + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_m_step_func_t)(void); +typedef size_t 
(*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_n_step_func_t)(void); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func_t)( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t)( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t)( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); + +/// Micro-kernel interface +struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_m_step_func_t get_m_step; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_n_step_func_t get_n_step; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t get_dst_offset; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t get_dst_size; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t run_imatmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..25a48afc6d939aede29405f2ae2b0cf30ad56181 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c @@ -0,0 +1,342 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited 
and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +#define MR 2 +#define KR 4 +#define MAX_M_STEP (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR) + +static size_t kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void) { + return MR * kai_get_sme_vector_length_u8() / KR; +} + +size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void) { + return kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(); +} + +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(m_idx % kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme() == 0); + + return m_idx * k_chunk_count * kai_roundup(k_chunk_length, KR) * sizeof(int8_t); +} + +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length) { + const size_t m_end = kai_roundup(m, kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme()); + return kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme(m_end, k_chunk_count, k_chunk_length); +} + +void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed) { + KAI_ASSUME(lhs_ptrs != NULL); + KAI_ASSUME(lhs_packed != NULL); + + const size_t m_step = kai_get_mr_lhs_imatmul_pack_x8p2vlx4_x8p_sme(); + const size_t row_offset = 0; + const size_t width = k_chunk_length; + + KAI_ASSERT(m_step <= MAX_M_STEP); + const uint8_t* in[MAX_M_STEP]; + + uint8_t* out_base = lhs_packed; + for (size_t i_m = 0; i_m < m; i_m += m_step) { + for 
(size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; i_k_chunk += 1) { + const size_t height = KAI_MIN(m - i_m, m_step); + void* out = out_base; + for (size_t y = 0; y < height; y += 1) { + KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); + in[y] = *(lhs_ptrs + i_m * k_chunk_count + i_k_chunk * m_step + y); + if (in[y] != pad_ptr) { + in[y] += lhs_ptr_offset; + } + } + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x23, %x[width]\n" + "mov x21, %x[width]\n" + "cntb x20\n" + "incb x23\n" + "sub x7, x20, #0x1\n" + "cntw x8\n" + "sub x23, x23, #0x1\n" + "ands x7, x21, x7\n" + "udiv x23, x23, x20\n" // n_passes = ceildiv(width, VL) + "csel x7, x7, x20, NE\n" + "lsl x22, %x[height], #0x1\n" // height * 2 + "lsl x21, x8, #0x1\n" + "sub x20, x23, #0x1\n" + "add x7, x7, #0x3\n" + "sub x17, x8, #0x2\n" + "whilelt p9.b, XZR, x22\n" + "whilelt p8.b, x21, x22\n" + "mov x16, #0x0\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + "cntw x9, ALL, MUL #2\n" + "cntw x28, ALL, MUL #3\n" + "ldr x27, [x11, #0x0]\n" + "lsr x20, x20, #0x1\n" // n_loops = (n_passes - 1) / 2 + "and x26, x23, #0x1\n" // odd_tail = bool(n_passes & 0x1) + "ldr x25, [x10, #0x0]\n" + "lsr x7, x7, #0x2\n" + "ptrue p11.s\n" + "ldr x24, [x11, #0x8]\n" + "zip1 p10.b, p9.b, p8.b\n" + "mov x23, %x[row_offset]\n" + "ldr x21, [x10, #0x8]\n" + "mov x22, %x[out]\n" + "whilelt p9.b, x16, %x[width]\n" + "whilelt p8.b, x16, %x[width]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "mov x12, #0x0\n" + "cbz x17, 2f\n" + "1:" // K loop: Charge: Loop + ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" + ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" + ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" + ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" + ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" + "ldr x25, [x10, #0x0]\n" + 
".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" + "add x12, x12, #0x8\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "cmp x12, x17, LSL #2\n" + "blt 1b\n" + "2:" // K loop: Charge: End + ".inst 0x25246143 // psel p3.b, p8.b/Z, p10.b[w12]\n" + ".inst 0x252c6142 // psel p2.b, p8.b/Z, p10.b[w12, #1]\n" + ".inst 0x25646141 // psel p1.b, p8.b/Z, p10.b[w12, #4]\n" + ".inst 0x256c6140 // psel p0.b, p8.b/Z, p10.b[w12, #5]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0170f60 // ld1b { za0h.b[x12] }, p3/Z, [x27, x23]\n" + "ldr x27, [x11, #0x0]\n" + "incb x16\n" + ".inst 0xe0170b21 // ld1b { za0h.b[x12, #1] }, p2/Z, [x25, x23]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0170704 // ld1b { za0h.b[x12, #4] }, p1/Z, [x24, x23]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01702a5 // ld1b { za0h.b[x12, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + "add x10, x10, #0x10\n" + "incb x23\n" + "cbz x20, 8f\n" + "mov x20, x20\n" + "3:" // K loop: Main loop + "whilelt p8.b, x16, %x[width]\n" + "mov x15, #0x0\n" + "mov x14, #0x0\n" + "cbz x17, 5f\n" + "4:" // K loop: Main loop: First: Loop + ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" + ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" + ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" + ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" + ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" + ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" + ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" + ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add 
x11, x11, #0x10\n" + ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" + "add x15, x15, #0x8\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w { za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x14, x14, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x14, x17\n" + "blt 4b\n" + "5:" // K loop: Main loop: First: Tail + ".inst 0x25376143 // psel p3.b, p8.b/Z, p10.b[w15, #2]\n" + ".inst 0x253f6142 // psel p2.b, p8.b/Z, p10.b[w15, #3]\n" + ".inst 0x25776141 // psel p1.b, p8.b/Z, p10.b[w15, #6]\n" + ".inst 0x257f6140 // psel p0.b, p8.b/Z, p10.b[w15, #7]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0176f62 // ld1b { za0h.b[x15, #2] }, p3/Z, [x27, x23]\n" + ".inst 0x25266d23 // psel p3.b, p11.b/Z, p9.b[w14]\n" + "ldr x27, [x11, #0x0]\n" + "mov x13, #0x0\n" + ".inst 0xe0176b23 // ld1b { za0h.b[x15, #3] }, p2/Z, [x25, x23]\n" + ".inst 0x25266d22 // psel p2.b, p11.b/Z, p9.b[w14]\n" + "ldr x25, [x10, #0x0]\n" + "mov x12, #0x0\n" + ".inst 0xe0176706 // ld1b { za0h.b[x15, #6] }, p1/Z, [x24, x23]\n" + ".inst 0x252e6d21 // psel p1.b, p11.b/Z, p9.b[w14, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01762a7 // ld1b { za0h.b[x15, #7] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252e6d20 // psel p0.b, p11.b/Z, p9.b[w14, #1]\n" + "whilelt p9.b, x16, %x[width]\n" + ".inst 0xe0bfcec0 // st1w { za0v.s[x14] }, p3/Z, [x22, XZR, LSL #2]\n" + "incb x16\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a8cac4 // st1w { za1v.s[x14] }, p2/Z, [x22, x8, LSL #2]\n" + "incb x23\n" + "whilelt p8.b, x16, %x[width]\n" + ".inst 0xe0a9c6c1 // st1w { za0v.s[x14, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bcc2c5 // st1w 
{ za1v.s[x14, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "cbz x17, 7f\n" + "6:" // K loop: Main loop: Second: Loop + ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" + ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" + ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" + ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" + ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" + ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" + ".inst 0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" + ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" + "add x10, x10, #0x10\n" + ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + "add x13, x13, #0x8\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "add x12, x12, #0x2\n" + "addvl x22, x22, #4\n" + "cmp x12, x17\n" + "blt 6b\n" + "7:" // K loop: Main loop: Second: Tail + ".inst 0x25256143 // psel p3.b, p8.b/Z, p10.b[w13]\n" + ".inst 0x252d6142 // psel p2.b, p8.b/Z, p10.b[w13, #1]\n" + ".inst 0x25656141 // psel p1.b, p8.b/Z, p10.b[w13, #4]\n" + ".inst 0x256d6140 // psel p0.b, p8.b/Z, p10.b[w13, #5]\n" + "mov x11, %x[in]\n" + "add x10, %x[in], x8, LSL #3\n" + ".inst 0xe0172f60 // ld1b { za0h.b[x13] }, p3/Z, [x27, x23]\n" + ".inst 0x25246d23 // psel p3.b, p11.b/Z, p9.b[w12]\n" + "ldr x27, [x11, #0x0]\n" + ".inst 0xe0172b21 // ld1b { za0h.b[x13, #1] }, p2/Z, [x25, x23]\n" + ".inst 
0x25246d22 // psel p2.b, p11.b/Z, p9.b[w12]\n" + "ldr x25, [x10, #0x0]\n" + ".inst 0xe0172704 // ld1b { za0h.b[x13, #4] }, p1/Z, [x24, x23]\n" + ".inst 0x252c6d21 // psel p1.b, p11.b/Z, p9.b[w12, #1]\n" + "ldr x24, [x11, #0x8]\n" + "add x11, x11, #0x10\n" + ".inst 0xe01722a5 // ld1b { za0h.b[x13, #5] }, p0/Z, [x21, x23]\n" + "ldr x21, [x10, #0x8]\n" + ".inst 0x252c6d20 // psel p0.b, p11.b/Z, p9.b[w12, #1]\n" + "whilelt p9.b, x16, %x[width]\n" + ".inst 0xe0bf8ec8 // st1w { za2v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + "subs x20, x20, #0x1\n" + "add x10, x10, #0x10\n" + ".inst 0xe0a88acc // st1w { za3v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "incb x16\n" + "incb x23\n" + ".inst 0xe0a986c9 // st1w { za2v.s[x12, #1] }, p1/Z, [x22, x9, LSL #2]\n" + ".inst 0xe0bc82cd // st1w { za3v.s[x12, #1] }, p0/Z, [x22, x28, LSL #2]\n" + "addvl x22, x22, #4\n" + "bgt 3b\n" + "8:" // K loop: Tails + "cbnz x26, 11f\n" + "mov x11, %x[in]\n" + "whilelt p8.b, x16, %x[width]\n" + "mov x13, #0x0\n" + "mov x12, #0x0\n" + "9:" // K loop: Tails: Even: First + ".inst 0x25306d23 // psel p3.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d22 // psel p2.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25356141 // psel p1.b, p8.b/Z, p10.b[w13, #2]\n" + ".inst 0x253d6140 // psel p0.b, p8.b/Z, p10.b[w13, #3]\n" + ".inst 0xe0bf8ec0 // st1w { za0v.s[x12] }, p3/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a88ac4 // st1w { za1v.s[x12] }, p2/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "ldr x21, [x11, #0x0]\n" + "cmp x12, x8\n" + "ldr x20, [x11, x8, LSL #0x3]\n" + "add x11, x11, #0x8\n" + ".inst 0xe01726a2 // ld1b { za0h.b[x13, #2] }, p1/Z, [x21, x23]\n" + ".inst 0xe0172283 // ld1b { za0h.b[x13, #3] }, p0/Z, [x20, x23]\n" + "add x13, x13, #0x4\n" + "blt 9b\n" + "whilelt p9.b, x16, %x[width]\n" + "whilelt p8.b, x16, %x[width]\n" + "mov x20, #0x0\n" + "mov x12, #0x0\n" + "10:" // K loop: Tails: Even: Second + ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d20 // psel p0.s, p11.s/Z, 
p9.s[w12]\n" + "add x20, x20, #0x4\n" + ".inst 0xe0bf86c8 // st1w { za2v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882cc // st1w { za3v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 10b\n" + "whilelt p8.b, x16, %x[width]\n" + "b 13f\n" + "11:" // K loop: Tails: Odd + "mov x12, #0x0\n" + "12:" // K loop: Tails: Odd: Loop + ".inst 0x25306d21 // psel p1.s, p11.s/Z, p9.s[w12]\n" + ".inst 0x25306d20 // psel p0.s, p11.s/Z, p9.s[w12]\n" + ".inst 0xe0bf86c0 // st1w { za0v.s[x12] }, p1/Z, [x22, XZR, LSL #2]\n" + ".inst 0xe0a882c4 // st1w { za1v.s[x12] }, p0/Z, [x22, x8, LSL #2]\n" + "add x12, x12, #0x1\n" + "addvl x22, x22, #2\n" + "cmp x12, x7\n" + "blt 12b\n" + "13:" // K loop: End + "mov %x[out], x22\n" + ".inst 0xd503467f // SMSTOP\n" + : [out] "+&r"(out) + : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", + "p7", "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "x7", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", + "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", + "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); + out_base += m_step * kai_roundup(k_chunk_length, KR) * sizeof(int8_t); + } + } +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..7136d837aa68230e701292cc152c81a883aabba8 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h @@ -0,0 +1,60 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return Step size for row index +size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void); + +/// Gets the offset in bytes to the data element in the packed LHS buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t m_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed LHS buffer. +/// +/// @param[in] m Number of rows in the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// +/// @return The size in bytes of the packed LHS buffer. +size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); + +/// Pack the LHS matrix for use with indirect matrix multiplication +/// +/// @param[in] m Number of rows of the unpacked LHS matrix. +/// @param[in] k_chunk_count Number of LHS column splits. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. 
+/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of +/// t `m * k_chunk_count` pointers. +/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs +/// array, excluding zero pointers. +/// @param[in] pad_ptr Pointer to chunk used for padding. @ref lhs_ptr_offset is +/// not applied to this pointer when used in @ref lhs_ptrs. This can +/// be NULL if there is no padding used @ref lhs_ptrs +/// @param[out] lhs_packed Packed LHS matrix. +void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( + size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, + const void* pad_ptr, void* lhs_packed); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c new file mode 100644 index 0000000000000000000000000000000000000000..3db501b6ea27f69ea8759782b336eda598093053 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -0,0 +1,276 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
+#include "kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_input = sizeof(uint8_t); +static const size_t kai_num_bytes_output = sizeof(uint8_t); +static const size_t kai_num_bytes_bias = sizeof(int32_t); +static const size_t kai_num_bytes_scale = sizeof(float32_t); + +#define NR 2 +#define KR 4 +#define MAX_N_STEP (NR * KAI_SME_VEC_LENGTH_MAX_BYTES / KR) + +size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) { + return NR * kai_get_sme_vector_length_u8() / KR; +} + +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); + + return n_idx * kai_num_bytes_input; +} + +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_bias; +} + +size_t kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) { + return n_idx * kai_num_bytes_scale; +} + +static size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t k_chunk_count, size_t k_chunk_length) { + return kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() * + (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output + + kai_num_bytes_scale); +} + +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) { + KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0); + + const size_t block_idx = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(); + return block_idx * + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); +} + +size_t 
kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length) { + const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme()); + return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + n_rounded_up, k_chunk_count, k_chunk_length); +} + +void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params) { + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(bias != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(params != NULL); + + size_t height = k_chunk_length; + const size_t width = n; + const void* in = rhs; + void* out = rhs_packed; + const size_t in_stride = rhs_row_stride; + + KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP); + uint8_t pad_row[MAX_N_STEP]; + if (height % KR) { + memset(pad_row, 0, MAX_N_STEP * sizeof(uint8_t)); + } + + size_t out_stride = + kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length); + const int32_t input_zero_point = params->lhs_zero_point; + const float scale_multiplier = params->scale_multiplier; + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x12, %x[out]\n" + "mov x11, %x[k_chunk_count]\n" + "ptrue p2.b\n" + "incb %x[out], ALL, MUL #2\n" + "1:" // Chunk Loop + "mov x10, %x[height]\n" + "cmp x10, #0x8\n" + "blt 5f\n" + "2:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[out]\n" + "add x27, x9, %x[in_stride]\n" + "sub x10, x10, #0x8\n" + "add x26, x27, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x25, x26, %x[in_stride]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, 
x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "3:" // Main row loop: Column loop + "whilelt p0.b, XZR, x24\n" + "decw x24, ALL, MUL #2\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "cmp x24, #0x0\n" + "incd x9, ALL, MUL #4\n" + "ld1b { z22.b }, p0/Z, [x27]\n" + "incd x27, ALL, MUL #4\n" + "ld1b { z17.b }, p0/Z, [x26]\n" + "incd x26, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x25]\n" + "incd x25, ALL, MUL #4\n" + "ld1b { z20.b }, p0/Z, [x23]\n" + "incd x23, ALL, MUL #4\n" + "ld1b { z19.b }, p0/Z, [x22]\n" + "zip1 z21.b, z18.b, z17.b\n" + "incd x22, ALL, MUL #4\n" + "ld1b { z18.b }, p0/Z, [x21]\n" + "zip1 z17.b, z22.b, z16.b\n" + "incd x21, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x20]\n" + "incd x20, ALL, MUL #4\n" + "zip1 z20.b, z20.b, z18.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "zip1 z19.b, z21.b, z17.b\n" + "zip2 z18.b, z21.b, z17.b\n" + "zip1 z17.b, z20.b, z16.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "st1b { z19.b }, p2, [x28]\n" + "st1b { z18.b }, p2, [x28, #1, MUL VL]\n" + "st1b { z17.b }, p2, [x28, #2, MUL VL]\n" + "st1b { z16.b }, p2, [x28, #3, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 3b\n" + "cmp x10, #0x8\n" + "addvl %x[out], %x[out], #4\n" + "bge 2b\n" + "cbz x10, 9f\n" + "5:" // Main loop skip + "6:" // Tail row loop: Head + "mov x9, %x[in]\n" + "cmp x10, #0x3\n" + "add x27, x9, %x[in_stride]\n" + "cntw x24, ALL, MUL #2\n" + "add x26, x27, %x[in_stride]\n" + "csel x23, x24, XZR, GT\n" + "add x25, x26, %x[in_stride]\n" + "csel x22, x24, XZR, GE\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x28, %x[out]\n" + "csel %x[in], %x[in], x25, GT\n" + "csel x25, x25, %x[pad_row], GT\n" + "csel %x[in], %x[in], x26, GE\n" + "csel x26, x26, %x[pad_row], GE\n" + "cmp x10, #0x1\n" + "sub x10, x10, #0x4\n" + "csel %x[in], %x[in], x27, GT\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x21, x24, XZR, GT\n" + "mov x20, %x[width]\n" + "7:" // Tail row loop: Column loop + "whilelt p0.b, XZR, x20\n" + "decw x20, ALL, MUL #2\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + 
"cmp x20, #0x0\n" + "add x9, x9, x24\n" + "ld1b { z19.b }, p0/Z, [x27]\n" + "add x27, x27, x21\n" + "ld1b { z17.b }, p0/Z, [x26]\n" + "add x26, x26, x22\n" + "ld1b { z16.b }, p0/Z, [x25]\n" + "add x25, x25, x23\n" + "zip1 z18.b, z18.b, z17.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "zip1 z17.b, z18.b, z16.b\n" + "zip2 z16.b, z18.b, z16.b\n" + "st1b { z17.b }, p2, [x28]\n" + "st1b { z16.b }, p2, [x28, #1, MUL VL]\n" + "add x28, x28, %x[out_stride]\n" + "bgt 7b\n" + "cmp x10, #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 6b\n" + "9:" // Done + "sub x11, x11, #0x1\n" + "cbnz x11, 1b\n" + "mov x22, %x[out]\n" + "mov x21, %x[width]\n" + "dup z18.s, %w[scale_multiplier]\n" + "cbz %x[scale], 11f\n" + "10:" // Scale: Full loop + "mov x20, x21\n" + "decw x21, ALL, MUL #2\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [%x[scale]]\n" + "cmp x21, #0x0\n" + "ld1w { z16.s }, p0/Z, [%x[scale], #1, MUL VL]\n" + "incb %x[scale], ALL, MUL #2\n" + "fmul z17.s, z17.s, z18.s\n" + "fmul z16.s, z16.s, z18.s\n" + "st1w { z17.s }, p2, [x22]\n" + "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Scale: Done + "cbz %x[width], 14f\n" + "cbz %x[height], 14f\n" + "dup z21.s, %w[input_zero_point]\n" + "add x25, %x[height], #0x3\n" + "cntw x24, ALL, MUL #2\n" + "mov z20.b, #0x1\n" + "lsr x25, x25, #0x2\n" + "mov x23, %x[width]\n" + "mul x25, %x[k_chunk_count], x25\n" + "addvl x22, x12, #2\n" + "neg z21.s, p2/M, z21.s\n" + "12:" // Bias: N loop + "mov x21, x22\n" + "mov x20, x25\n" + "mov z19.s, #0x0\n" + "mov z18.s, #0x0\n" + "13:" // Bias: K loop + "ld1b { z17.b }, p2/Z, [x21]\n" + "subs x20, x20, #0x1\n" + "ld1b { z16.b }, p2/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "sdot z19.s, z17.b, z20.b\n" + "sdot z18.s, z16.b, z20.b\n" + "bgt 13b\n" + "mov x20, x23\n" + "add x22, x22, %x[out_stride]\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, 
p1/Z, [%x[bias]]\n" + "subs x23, x23, x24\n" + "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" + "addvl %x[bias], %x[bias], #2\n" + "mla z17.s, p2/M, z19.s, z21.s\n" + "mla z16.s, p2/M, z18.s, z21.s\n" + "st1w { z17.s }, p2, [x12]\n" + "st1w { z16.s }, p2, [x12, #1, MUL VL]\n" + "add x12, x12, %x[out_stride]\n" + "bgt 12b\n" + "14:" // Bias: Done + ".inst 0xd503467f // SMSTOP\n" + : [bias] "+&r"(bias), [in] "+&r"(in), [out] "+&r"(out), [scale] "+&r"(scale) + : [height] "r"(height), [in_stride] "r"(in_stride), [input_zero_point] "r"(input_zero_point), + [k_chunk_count] "r"(k_chunk_count), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), + [scale_multiplier] "r"(scale_multiplier), [width] "r"(width) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", + "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h new file mode 100644 index 0000000000000000000000000000000000000000..77b1dbde73f1273d9de6b1acf1009421898b2015 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -0,0 +1,90 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return Step size for column index. 
+size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void); + +/// Gets the offset in bytes to the data element in the RHS matrix buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the bias buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the scale buffer. +/// +/// @param[in] n_idx Column index. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx); + +/// Gets the offset in bytes to the data element in the packed RHS buffer. +/// +/// @param[in] n_idx Column index. Must be divisible by `n_step` +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n_idx, size_t k_chunk_count, size_t k_chunk_length); + +/// Gets the size in bytes of the packed RHS buffer. +/// +/// @param[in] n Number of columns. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// +/// @return The size in bytes of the packed RHS buffer. +size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length); + +/// Runs the RHS packing function for matrix multiplication. 
+/// +/// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset +/// calculated using the following functions: +/// +/// * RHS: @ref kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Scale: @ref kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme. +/// +/// @param[in] n Number of columns of the output matrix. +/// @param[in] k_chunk_count Number of chunks. +/// @param[in] k_chunk_length Number of rows in each chunk. +/// @param[in] rhs_row_stride Row stride in bytes of the RHS matrix. +/// @param[in] rhs RHS matrix data buffer. +/// @param[in] bias Bias matrix data buffer. +/// @param[in] scale Scale data buffer. +/// @param[out] rhs_packed Packed RHS matrix. +/// @param[in] params Extra packing parameters. +void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( + size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, + const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/common/test_suite.hpp b/test/common/test_suite.hpp index cc79103785148c5da7493334c21fe17d1a7c6c8f..1bc851774da8c695ae3574f71fe63d0054ce4c24 100644 --- a/test/common/test_suite.hpp +++ b/test/common/test_suite.hpp @@ -76,6 +76,22 @@ struct MatMulShape { size_t m{}; ///< LHS height. size_t n{}; ///< RHS width. size_t k{}; ///< LHS width and RHS height. 
+private: + friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) { + return // + lhs.m == rhs.m && // + lhs.n == rhs.n && // + lhs.k == rhs.k; + } +}; + +struct HashMatMulShape { + size_t operator()(const kai::test::MatMulShape& shape) const { + return // + (std::hash{}(shape.m) << 0) ^ // + (std::hash{}(shape.n) << 1) ^ // + (std::hash{}(shape.k) << 2); + } }; /// Matrix multiplication test information. diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index a5fd1e66b53f8cbbf3cddd86c5fa3a39289b91bf..e735b23f553730a61f30e02d86399a3d3b5b5e02 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -12,7 +12,6 @@ #include #include "kai/kai_common.h" -#include "test/common/bfloat16.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" #include "test/common/float16.hpp" @@ -185,6 +184,82 @@ std::vector matmul( return tmp_dst; } +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> +std::vector indirect_matmul_nt_t_quantized( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) { + const auto lhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, lhs_quant_width); + const auto rhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, rhs_quant_width); + + std::vector dst(m * n * sizeof(DstData)); + + for (size_t i_m = 0; i_m < m; ++i_m) { + for (size_t i_n = 0; i_n < 
n; ++i_n) { + DstData acc = 0; + + for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; ++i_k_chunk) { + // Calculate the K chunk pointer. Apply offset if this is not padding + const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk; + const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx]; + if (k_chunk_ptr != lhs_padding) { + k_chunk_ptr = reinterpret_cast(reinterpret_cast(k_chunk_ptr) + lhs_offset); + } + + for (size_t i_k_chunk_len = 0; i_k_chunk_len < k_chunk_length; ++i_k_chunk_len) { + const size_t i = i_k_chunk * k_chunk_length + i_k_chunk_len; + + const auto lhs_data_index = i_k_chunk_len; + const auto lhs_quant_index = (i_m / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; + const auto lhs_value = read_array(k_chunk_ptr, lhs_data_index); + const auto lhs_scale = lhs_scales != nullptr ? read_array(lhs_scales, lhs_quant_index) + : static_cast(1); + const auto lhs_zero_point = lhs_zero_points != nullptr + ? read_array(lhs_zero_points, lhs_quant_index) + : static_cast(0); + + const auto rhs_data_index = i_n * (k_chunk_count * k_chunk_length) + i; + const auto rhs_quant_index = (i_n / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; + const auto rhs_value = read_array(rhs_data, rhs_data_index); + const auto rhs_scale = rhs_scales != nullptr ? read_array(rhs_scales, rhs_quant_index) + : static_cast(1); + const auto rhs_zero_point = rhs_zero_points != nullptr + ? read_array(rhs_zero_points, rhs_quant_index) + : static_cast(0); + + acc += (static_cast(lhs_value) - static_cast(lhs_zero_point)) * + static_cast(lhs_scale) * + (static_cast(rhs_value) - static_cast(rhs_zero_point)) * + static_cast(rhs_scale); + } + } + + if (bias_data != nullptr) { + const auto bias_value = read_array(bias_data, i_n); + const auto bias_scale = bias_scales != nullptr + ? read_array(bias_scales, i_n / bias_quant_width) + : static_cast(1); + const auto bias_zero_point = bias_zero_points != nullptr + ? 
read_array(bias_zero_points, i_n / bias_quant_width) + : static_cast(0); + + acc += (static_cast(bias_value) - static_cast(bias_zero_point)) * + static_cast(bias_scale); + } + + write_array(dst.data(), i_m * n + i_n, acc); + } + } + + return dst; +} + template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> @@ -207,7 +282,7 @@ std::vector matmul_nt_t_quantized( for (size_t i = 0; i < k; ++i) { const auto lhs_data_index = row * k + i; - const auto lhs_quant_index = row / lhs_quant_height * lhs_num_quant_per_row + i / lhs_quant_width; + const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; const auto lhs_value = read_array(lhs_data, lhs_data_index); const auto lhs_scale = lhs_scales != nullptr ? read_array(lhs_scales, lhs_quant_index) : static_cast(1); @@ -216,7 +291,7 @@ std::vector matmul_nt_t_quantized( : static_cast(0); const auto rhs_data_index = col * k + i; - const auto rhs_quant_index = col / rhs_quant_height * rhs_num_quant_per_row + i / rhs_quant_width; + const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; const auto rhs_value = read_array(rhs_data, rhs_data_index); const auto rhs_scale = rhs_scales != nullptr ? 
read_array(rhs_scales, rhs_quant_index) : static_cast(1); @@ -259,6 +334,16 @@ matmul_nt_t_quantized +indirect_matmul_nt_t_quantized( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // + const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + template < typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 9a8ce9f809558a77f30c66912ba3dc6206fb596a..8d83e98c2455cb8e7c972d097586e1eefd55b5aa 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -122,7 +122,12 @@ std::vector matmul_clamp_nt_t( /// @param[in] m The LHS and output height. /// @param[in] n The RHS height and output width. /// @param[in] k The LHS and RHS width. +/// @param[in] k_chunk_count Number of K chunk pointers per row in lhs_idata matrix +/// @param[in] k_chunk_length Lenght of each K chunk pointed to in lhs_idata matrix /// @param[in] lhs_data The LHS data matrix. +/// @param[in] lhs_idata The indirect LHS data matrix. +/// @param[in] lhs_offset The indirection LHS data matrix offset, applied to non-padding pointers +/// @parma[in] lhs_padding The indirection LHS padding chunk pointer /// @param[in] lhs_scales The LHS quantization scales matrix. 
/// @param[in] lhs_zero_points The LHS quantization zero points matrix. /// @param[in] lhs_quant_width The LHS quantization block width. @@ -156,4 +161,16 @@ std::vector matmul_nt_t_quantized( size_t rhs_quant_width, // const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> +std::vector indirect_matmul_nt_t_quantized( + size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // + const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, + const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + } // namespace kai::test diff --git a/test/reference/reorder.cpp b/test/reference/reorder.cpp index 564f96f6ae6f9dbd4cd7fd7d24834bf079f6fe9e..61ba67d1cf2779ed9aa9cf7475c23bc00ec108dd 100644 --- a/test/reference/reorder.cpp +++ b/test/reference/reorder.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -46,5 +46,7 @@ std::vector reorder_block( template std::vector reorder_block( const void* src, size_t height, size_t width, size_t block_height, size_t block_width); +template std::vector reorder_block( + const void* src, size_t height, size_t width, size_t block_height, size_t block_width); } // namespace kai::test diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp 
b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 2ffe5b7070d6ecfd116618eefceb8a8fe362ffa2..d1529ea0b04b637c7b52348ca238a64bccc47a03 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -9,14 +9,25 @@ #include #include #include +#include +#include +#include #include #include +#include #include +#include +#include +#include "kai/kai_common.h" +#include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" #include "test/common/cpu_info.hpp" #include "test/common/matrix_portion.hpp" @@ -37,6 +48,12 @@ namespace kai::test { using Buffer = std::vector; +using IndirectionBuffer = std::vector; + +struct KChunk { + size_t count; + size_t length; +}; struct LhsPackKernel { std::function get_m_step; @@ -49,6 +66,16 @@ struct LhsPackKernel { pack; }; +struct LhsPackIndirectKernel { + std::function get_m_step; + std::function get_packed_lhs_offset; + std::function get_packed_lhs_size; + std::function + pack; +}; + struct RhsPackKernel { std::function get_n_step; std::function get_rhs_offset; @@ -63,6 +90,19 @@ struct RhsPackKernel { 
pack; }; +struct RhsPackIndirectKernel { + std::function get_n_step; + std::function get_rhs_offset; + std::function get_bias_offset; + std::function get_scale_offset; + std::function get_packed_rhs_offset; + std::function get_packed_rhs_size; + std::function + pack; +}; + struct MatMulKernel { std::function get_m_step; std::function get_n_step; @@ -80,6 +120,19 @@ struct MatMulKernel { matmul; }; +struct MatMulIndirectKernel { + std::function get_m_step; + std::function get_n_step; + std::function get_packed_lhs_offset; + std::function get_packed_rhs_offset; + std::function get_dst_offset; + std::function get_dst_size; + std::function + imatmul; +}; + const static RhsPackKernel rhs_pack = { .get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, .get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, @@ -102,6 +155,18 @@ struct MatMulVariant { MatMulKernel matmul; ///< Matmul kernel interface }; +struct IndirectMatMulVariant { + std::string_view name; ///< Test identification + MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) + MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) + + std::function is_supported; ///< HW support check + + LhsPackIndirectKernel lhs_pack; ///< LHS packing kernel interface + RhsPackIndirectKernel rhs_pack; ///< RHS packing kernel interface + MatMulIndirectKernel matmul; ///< Matmul kernel interface +}; + const std::array gemm_variants = { MatMulVariant{ .name = "matmul_qai8_qai8p_qsi8cxp", @@ -146,6 +211,55 @@ const std::array gemm_variants = { }, }; +const std::array indirect_gemm_variants = { + IndirectMatMulVariant{ + .name = "indirect_matmul_qai8_qai8p_qsi8cxp", + .acc_pack{ + .m = 2 * get_sme_vector_length(), + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + .acc_step{ + .m = 2 * get_sme_vector_length(), + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + + .is_supported = cpu_has_sme2, 
+ + .lhs_pack = + LhsPackIndirectKernel{ + .get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + .get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + .get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + .pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme, + }, + .rhs_pack = + RhsPackIndirectKernel{ + .get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .get_packed_rhs_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + .pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, + }, + .matmul = + MatMulIndirectKernel{ + .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_packed_lhs_offset = + kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_packed_rhs_offset = + kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + }, + }, +}; + const std::array gemv_variants = { MatMulVariant{ .name = "matmul_qai8_qai8_qsi8cxp", @@ -185,8 +299,7 @@ const std::array gemv_variants = { }, }; -constexpr uint32_t seed = 0; ///< Random 
seed used for tests -constexpr float output_clamp_rate = 0.1F; ///< Clamping range in ration of output +constexpr uint32_t seed = 0; ///< Random seed used for tests /// Value range template @@ -215,6 +328,10 @@ struct TestReference { Buffer lhs_qai8; Buffer lhs_qai8_scales; Buffer lhs_qai8_zero_points; + IndirectionBuffer lhs_qai8_indirect; + Buffer lhs_qai8_indirect_packed; + Buffer lhs_qai8_indirect_padding; + size_t lhs_qai8_indirect_offset; Buffer rhs_qsi8; Buffer rhs_scales; @@ -242,16 +359,76 @@ static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel matmul_clamp_qai8_qai8_ .run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, }; +/// Make sure that interface matches +static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel + imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface [[maybe_unused]] = { + .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, +}; + +static constexpr int8_t padding_value = 0; + +// Functionality for hashing generated test data. 
+// This is particularly useful for portion testing +// which reuses the exact same data for all portions +struct TestDataId { + MatMulShape shape; + MatMulShape shape_pack; + size_t chunk_len; + bool pad_testing; + float clamp_ratio; + +private: + friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { + return // + lhs.shape == rhs.shape && // + lhs.shape_pack == rhs.shape_pack && // + lhs.chunk_len == rhs.chunk_len && // + lhs.pad_testing == rhs.pad_testing && // + lhs.clamp_ratio == rhs.clamp_ratio; + } +}; + +struct HashTestDataId { + size_t operator()(const TestDataId& id) const { + return // + (HashMatMulShape{}(id.shape) << 0) ^ // + (HashMatMulShape{}(id.shape_pack) << 1) ^ // + (std::hash{}(id.chunk_len) << 2) ^ // + (std::hash{}(id.pad_testing) << 3) ^ // + (std::hash{}(id.clamp_ratio) << 4); + } +}; + +// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) +static std::unordered_map g_data; +// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) + /// Generate test reference data -static TestReference get_test_reference(const MatMulShape& shape, const MatMulVariant& variant) { +static const TestReference& get_test_reference(const TestDataId& test_data_id) { // ============================================================ // Generates input and reference output data // ============================================================ + // Attempt to find test data in cache + const auto data_it = g_data.find(test_data_id); + if (data_it != g_data.end()) { + return data_it->second; + } + + const auto& [shape, pack_shape, k_chunk_len, pad_testing, clamp_ratio] = test_data_id; + // Generates the input data in floating-point. 
- const auto lhs_f32 = fill_random(shape.m * shape.k, seed); - const auto rhs_f32 = fill_random(shape.k * shape.n, seed); - const auto bias_f32 = fill_random(shape.n, seed); + Buffer lhs_f32 = fill_random(shape.m * shape.k, seed); + const Buffer rhs_f32 = fill_random(shape.k * shape.n, seed); + const Buffer bias_f32 = fill_random(shape.n, seed); // Quantizes the input data. // * LHS: 8-bit asymmetric per-matrix quantization. @@ -265,6 +442,31 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa const auto lhs_scale = read_array(lhs_qai8_scales.data(), 0); const auto lhs_zero_point = read_array(lhs_qai8_zero_points.data(), 0); + const size_t k_chunk_count = shape.k / k_chunk_len; + assert(k_chunk_count * k_chunk_len == shape.k); + + // Setup an indirection buffer, where each "row" contains `k_chunk_count` + // pointers to chunks of length `k_chunk_len` in the input_buffer + IndirectionBuffer lhs_qai8_indirect(shape.m * k_chunk_count); + Buffer lhs_padding(k_chunk_len, padding_value); + for (size_t m_i = 0; m_i < shape.m; ++m_i) { + for (size_t k_chunk_idx = 0; k_chunk_idx < k_chunk_count; ++k_chunk_idx) { + const size_t idx = m_i * k_chunk_count + k_chunk_idx; + if (pad_testing and m_i == 0) { + // Push padding pointers for first row + lhs_qai8_indirect[idx] = lhs_padding.data(); + } else { + uintptr_t offset = m_i * shape.k + k_chunk_idx * k_chunk_len; + lhs_qai8_indirect[idx] = reinterpret_cast(offset); + } + } + } + const auto indirection_base = reinterpret_cast(lhs_qai8.data()); + + // Reorder indirection pointers to the layout the packing kernel expects + Buffer lhs_qai8_indirect_packed = reorder_block( + reinterpret_cast(lhs_qai8_indirect.data()), shape.m, k_chunk_count, pack_shape.m, 1); + // Transpose, then quantize symmetrically, then transpose back.
This will give one // quantization value for each column const auto rhs_f32_t = transpose(rhs_f32.data(), shape.k, shape.n); @@ -280,10 +482,12 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa quantize_symmetric_per_block(bias_f32.data(), bias_scales.data(), shape.n, 1, 1); // Runs the reference implementation of matmul to produce floating-point result. + const void* const* lhs_iptr = reinterpret_cast(lhs_qai8_indirect.data()); const auto ref_dst_f32 = - matmul_nt_t_quantized( - shape.m, shape.n, shape.k, // matmul shape - lhs_qai8.data(), &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point + indirect_matmul_nt_t_quantized( + shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape + lhs_iptr, indirection_base, lhs_padding.data(), // LHS indirection, offset and padding + &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point shape.m, shape.k, // LHS quantization window shape rhs_qsi8_t.data(), rhs_scales.data(), nullptr, // RHS scaling factors 1, shape.k, // RHS quantization window shape @@ -309,8 +513,8 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa const auto ref_dst_f32_max = reduce_max(ref_dst_f32.data(), shape.m * shape.n); const auto ref_dst_f32_range = ref_dst_f32_max - ref_dst_f32_min; - const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * output_clamp_rate / 2; - const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * output_clamp_rate / 2; + const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * clamp_ratio / 2; + const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * clamp_ratio / 2; const auto dst_qai8_clamp_min = quantize_asymmetric(ref_dst_f32_clamp_min, dst_scale, dst_zero_point); const auto dst_qai8_clamp_max = @@ -330,12 +534,12 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa // The reference packing functions cannot be executed earlier // 
because we need the reference floating-point output first to have // the quantization information. - auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k); + auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, pack_shape.m, pack_shape.k); auto packed_rhs = matmul_pack_rhs_nxk_static_quantized( rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, - variant.acc_pack.n, variant.acc_pack.k); + pack_shape.n, pack_shape.k); - return { + const TestReference& reference = g_data[test_data_id] = { .clamp = {.min = dst_qai8_clamp_min, .max = dst_qai8_clamp_max}, .qa_lhs = {.scale = lhs_scale, .zero_point = lhs_zero_point}, @@ -344,6 +548,10 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa .lhs_qai8 = std::move(lhs_qai8), .lhs_qai8_scales = std::move(lhs_qai8_scales), .lhs_qai8_zero_points = std::move(lhs_qai8_zero_points), + .lhs_qai8_indirect = std::move(lhs_qai8_indirect), + .lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed), + .lhs_qai8_indirect_padding = std::move(lhs_padding), + .lhs_qai8_indirect_offset = indirection_base, .rhs_qsi8 = std::move(rhs_qsi8), .rhs_scales = std::move(rhs_scales), @@ -355,6 +563,7 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa .packed_lhs = std::move(packed_lhs), .packed_rhs = std::move(packed_rhs), }; + return reference; } /// Test LHS packing @@ -432,6 +641,39 @@ static void test_rhs_pack( ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing"; } +static void compare_matmul_result( + const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) { + size_t mismatches = 0; + bool printed_row = false; + std::ostringstream sstream; + for (size_t m_i = 0; m_i < shape.m; ++m_i) { + for (size_t n_i = 0; n_i < shape.n; ++n_i) { + const auto i = m_i * shape.n + n_i; + 
const auto in_area = m_i >= output_area.start_row() && m_i < output_area.end_row() && + n_i >= output_area.start_col() && n_i < output_area.end_col(); + + const auto imp_value = read_array(actual.data(), i); + const auto ref_value = in_area ? read_array(reference.data(), i) : 0; + const auto error = std::abs(imp_value - ref_value); + const auto threshold = in_area ? 1 : 0; + const bool mismatch = error > threshold; + if (mismatch) { + if (not printed_row) { + sstream << " row=" << m_i << ", columns: "; + printed_row = true; + } + sstream << n_i << ", "; + } + mismatches += static_cast(mismatch); + } + if (printed_row) { + sstream << "\n"; + } + printed_row = false; + } + ASSERT_EQ(mismatches, 0) << "Mismatches between reference result and actual result:\n" << sstream.str(); +} + /// Test MatMul of GEMM/GEMV like kernel static void test_matmul( const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { @@ -461,48 +703,53 @@ static void test_matmul( reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t), sizeof(int8_t), &imp_main_params); - size_t mismatches = 0; - for (size_t y = 0; y < shape.m; ++y) { - for (size_t x = 0; x < shape.n; ++x) { - const auto i = y * shape.n + x; - const auto in_area = y >= output_area.start_row() && y < output_area.end_row() && - x >= output_area.start_col() && x < output_area.end_col(); - - const auto imp_value = read_array(imp_dst.data(), i); - const auto ref_value = in_area ? read_array(reference.dst_qsi8_clamped.data(), i) : 0; - const auto error = std::abs(imp_value - ref_value); - const auto threshold = in_area ? 
1 : 0; - - mismatches += static_cast(error > threshold); - } - } - ASSERT_EQ(mismatches, 0) << "There are mismatched between reference result actual result"; + compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped); } -using ThisTest = testing::TestWithParam>; +using MatMulQuantizedTest = testing::TestWithParam>; +using IndirectMatMulQuantizedTest = + testing::TestWithParam>; static std::string test_description( const MatMulVariant& variant, // const MatMulShape& shape, // - const MatrixPortion& portion) { + const MatrixPortion& portion, float clamp_ratio) { std::stringstream sstream; sstream << "Method_" << variant.name << "__M_" // << shape.m << "__N_" << shape.n << "__K_" << shape.k // << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // << "__PortionHeight_" << static_cast(portion.height() * 1000) // - << "__PortionWidth_" << static_cast(portion.width() * 1000); + << "__PortionWidth_" << static_cast(portion.width() * 1000) // + << "__clamp_ratio_" << static_cast(clamp_ratio * 100); return sstream.str(); }; -TEST_P(ThisTest, EndToEnd) { - const auto& [variant, shape, output_portion] = GetParam(); +static std::string test_description( + const IndirectMatMulVariant& variant, // + const MatMulShape& shape, // + const MatrixPortion& portion, size_t k_chunk_len, float clamp_ratio) { + std::stringstream sstream; + sstream << "Method_" << variant.name << "__M_" // + << shape.m << "__N_" << shape.n << "__k_chunk_count_" << shape.k // + << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // + << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // + << "__PortionHeight_" << static_cast(portion.height() * 1000) // + << "__PortionWidth_" << static_cast(portion.width() * 1000) // + << "__k_chunk_len_" << k_chunk_len // + << "__clamp_ratio_" << static_cast(clamp_ratio * 100); + return sstream.str(); +}; + +TEST_P(MatMulQuantizedTest, 
EndToEnd) { + const auto& [variant, shape, output_portion, clamp_ratio] = GetParam(); if (!variant.is_supported()) { GTEST_SKIP() << "CPU features are not supported by current CPU"; } - TestReference reference = get_test_reference(shape, variant); + TestDataId test_data_id{shape, variant.acc_pack, shape.k, false, clamp_ratio}; + const TestReference& reference = get_test_reference(test_data_id); // Check scheduling parameters const auto imp_mr = variant.matmul.get_mr(); @@ -532,56 +779,204 @@ TEST_P(ThisTest, EndToEnd) { test_matmul(shape, variant, matmul_portion, reference); } +namespace imatmul { + +/// Perform LHS IMATMUL packing +static Buffer lhs_pack( + const LhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t m, + const KChunk& k_chunk) { + const void* const* indirection_pointer = + reinterpret_cast(reference.lhs_qai8_indirect_packed.data()); + + // Allocate buffer + const size_t dst_size = variant.get_packed_lhs_size(m, k_chunk.count, k_chunk.length); + Buffer packed(dst_size); + + // Calculate offsets + const size_t input_offset = portion.start_row() * k_chunk.count; + const size_t dst_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); + + variant.pack( + portion.height(), k_chunk.count, k_chunk.length, // Dimensions + indirection_pointer + input_offset, // Indirection input + reference.lhs_qai8_indirect_offset, // chunk offset + reference.lhs_qai8_indirect_padding.data(), // padding pointer + packed.data() + dst_offset); + + return packed; +} + +/// Perform RHS IMATMUL packing +static Buffer rhs_pack( + const RhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t n, + const KChunk& k_chunk) { + // Allocate output buffer + const size_t dst_size = variant.get_packed_rhs_size(n, k_chunk.count, k_chunk.length); + Buffer packed_all(dst_size); + Buffer packed(dst_size); + + // Calculate effective quantization parameters + const 
kai_rhs_pack_qsi8cx_params quantization{ + reference.qa_lhs.zero_point, + reference.qa_lhs.scale / reference.qa_dst.scale, + }; + + // Calculate offsets + const size_t rhs_offset = variant.get_rhs_offset(portion.start_col()); + const size_t bias_offset = variant.get_bias_offset(portion.start_col()); + const size_t scale_offset = variant.get_scale_offset(portion.start_col()); + const size_t dst_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + + // Pack + variant.pack( + portion.width(), k_chunk.count, k_chunk.length, + n * sizeof(uint8_t), // Dimensions, row stride + reference.rhs_qsi8.data() + rhs_offset, // RHS matrix + reference.bias_qsi32.data() + bias_offset, // Bias + reference.rhs_scales.data() + scale_offset, // Scales + packed.data() + dst_offset, // Output + &quantization); + + return packed; +} + +/// Calculate the matmul result from IMATMUL kernels +static Buffer matmul( + const MatMulIndirectKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs, + const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) { + // Calculate portion offsets. 
+ size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n); + size_t lhs_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); + size_t rhs_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); + + // Allocate output buffer + const size_t dst_size = variant.get_dst_size(shape.m, shape.n); + Buffer dst(dst_size); + + // Calculate effective quantization parameters + kai_matmul_requantize32_params requantization{ + .min_value = reference.clamp.min, + .max_value = reference.clamp.max, + .output_zero_point = reference.qa_dst.zero_point, + }; + + // Call matmul kernel + variant.imatmul( + portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions + packed_lhs.data() + lhs_offset, // LHS + packed_rhs.data() + rhs_offset, // RHS + dst.data() + dst_offset, // DST + shape.n * sizeof(uint8_t), &requantization); + + return dst; +} +} // namespace imatmul + +TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { + /* This is a bit special, as shape.k must be k_chunk_len * k_chunk_count + * so instead of inventing a new special kind of shape, simply multiply + * with `k_chunk_len` here */ + const auto& [variant, shape_k_chunk, output_portion, k_chunk_len, clamp_ratio] = GetParam(); + const KChunk k_chunk{shape_k_chunk.k, k_chunk_len}; + MatMulShape shape{shape_k_chunk.m, shape_k_chunk.n, k_chunk.count * k_chunk.length}; + + if (!variant.is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + // Toggle padding tests when LHS has more than one row + TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1, clamp_ratio}; + const TestReference& reference = get_test_reference(test_data_id); + const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); + + Buffer packed_lhs = imatmul::lhs_pack(variant.lhs_pack, portion, reference, shape.m, k_chunk); + 
Buffer packed_rhs = imatmul::rhs_pack(variant.rhs_pack, portion, reference, shape.n, k_chunk); + Buffer impl_result = imatmul::matmul(variant.matmul, portion, reference, packed_lhs, packed_rhs, shape, k_chunk); + compare_matmul_result(shape, portion, impl_result, reference.dst_qsi8_clamped); +} + +static constexpr std::array shapes{ + // clang-format off + MatMulShape{ 1, 1, 1}, + MatMulShape{ 1, 16, 4}, + MatMulShape{ 1, 16, 16}, + MatMulShape{ 1, 17, 4}, + MatMulShape{ 1, 19, 24}, + MatMulShape{ 1, 32, 4}, + MatMulShape{ 1, 32, 32}, + MatMulShape{ 1, 33,200}, + MatMulShape{ 1, 49, 21}, + MatMulShape{ 1, 64, 4}, + MatMulShape{ 1, 65, 4}, + MatMulShape{ 1, 300, 10}, + MatMulShape{ 1, 512, 4}, + MatMulShape{ 1, 1523, 10}, + MatMulShape{ 2, 195, 50}, + MatMulShape{ 3, 6, 6}, + MatMulShape{ 3, 28, 25}, + MatMulShape{ 3, 184,177}, + MatMulShape{ 4, 16, 27}, + MatMulShape{ 5, 136, 23}, + MatMulShape{ 6, 18, 31}, + MatMulShape{ 6, 28, 1}, + MatMulShape{ 6, 29, 24}, + MatMulShape{ 16, 16, 4}, + MatMulShape{ 20, 30, 40}, + MatMulShape{ 23, 1, 43}, + MatMulShape{ 32, 14, 1}, + MatMulShape{ 32, 16, 27}, + MatMulShape{ 32, 32, 3}, + MatMulShape{ 32, 32, 4}, + MatMulShape{ 33, 29, 24}, + MatMulShape{ 64, 64, 3}, + MatMulShape{ 64, 64, 4}, + MatMulShape{ 96, 96, 3}, + MatMulShape{123, 85, 45}, + MatMulShape{128, 128, 3}, + MatMulShape{130, 130, 6}, + // clang-format on +}; + INSTANTIATE_TEST_SUITE_P( - matmul_clamp_qai8_qai8p_qsi8cxp, ThisTest, + matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest, testing::Combine( - testing::ValuesIn(gemm_variants), - testing::ValuesIn({ - // clang-format off - MatMulShape{ 1, 1, 1}, - MatMulShape{ 1, 49, 21}, - MatMulShape{ 16, 16, 4}, - MatMulShape{ 20, 30, 40}, - MatMulShape{ 23, 1, 43}, - MatMulShape{ 32, 14, 1}, - MatMulShape{ 32, 32, 4}, - MatMulShape{ 64, 64, 4}, - MatMulShape{123, 85, 45}, - MatMulShape{130, 130, 6}, - // clang-format on - }), + testing::ValuesIn(gemm_variants), testing::ValuesIn(shapes), testing::ValuesIn({ // 
clang-format off MatrixPortion( 0, 0, 1, 1), // Full matrix. MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner. MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. // clang-format on - })), + }), + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), [](const auto& info) -> std::string { return test_description( std::get(info.param), // std::get(info.param), // - std::get(info.param)); + std::get(info.param), // + std::get(info.param)); }); INSTANTIATE_TEST_SUITE_P( - matmul_clamp_qai8_qai8_qsi8cxp, ThisTest, + matmul_clamp_qai8_qai8_qsi8cxp, MatMulQuantizedTest, testing::Combine( testing::ValuesIn(gemv_variants), testing::ValuesIn({ // clang-format off - MatMulShape{1, 1, 1}, - MatMulShape{1, 16, 4}, - MatMulShape{1, 16, 16}, - MatMulShape{1, 17, 4}, - MatMulShape{1, 32, 4}, - MatMulShape{1, 32, 32}, - MatMulShape{1, 33, 200}, - MatMulShape{1, 64, 4}, - MatMulShape{1, 65, 4}, - MatMulShape{1, 300, 10}, - MatMulShape{1, 512, 4}, - MatMulShape{1, 1523, 10}, + MatMulShape{ 1, 1, 1}, + MatMulShape{ 1, 16, 4}, + MatMulShape{ 1, 16, 16}, + MatMulShape{ 1, 17, 4}, + MatMulShape{ 1, 19, 24}, + MatMulShape{ 1, 32, 4}, + MatMulShape{ 1, 32, 32}, + MatMulShape{ 1, 33,200}, + MatMulShape{ 1, 49, 21}, + MatMulShape{ 1, 64, 4}, + MatMulShape{ 1, 65, 4}, + MatMulShape{ 1, 300, 10}, + MatMulShape{ 1, 512, 4}, + MatMulShape{ 1, 1523, 10}, // clang-format on }), testing::ValuesIn({ @@ -591,11 +986,42 @@ INSTANTIATE_TEST_SUITE_P( MatrixPortion(0, 0, 1, .5), // Left half MatrixPortion(0, .25, 1, .5) // Middle half // clang-format on - })), + }), + // Clamp range + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), [](const auto& info) -> std::string { return test_description( std::get(info.param), // std::get(info.param), // - std::get(info.param)); + std::get(info.param), // + std::get(info.param)); + }); + +INSTANTIATE_TEST_SUITE_P( + indirect_matmul_clamp_qai8_qai8p_qsi8cxp, IndirectMatMulQuantizedTest, + testing::Combine( + 
testing::ValuesIn(indirect_gemm_variants), testing::ValuesIn(shapes), + testing::ValuesIn({ + // clang-format off + // (Start row , start col , height , width) + MatrixPortion( 0 , 0 , 1 , 1) , // Full matrix. + MatrixPortion( 0 , 0 , 1 , 0.5) , // Left half + MatrixPortion( 0 , 0 , 0.5 , 1) , // Upper half + MatrixPortion( 0 , 0.5 , 1 , 0.5) , // Right half + MatrixPortion( 0.5 , 0 , 0.5 , 1) , // Bottom half + MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3) , // Center ninth + // clang-format on + }), + // k_chunk_len + testing::ValuesIn(std::initializer_list{1, 2, 3, 4, 8, 11, 32}), + // Clamp range + testing::ValuesIn(std::initializer_list{0.0F, 0.1F, 0.5F})), + [](const auto& info) -> std::string { + return test_description( + std::get(info.param), // + std::get(info.param), // + std::get(info.param), // + std::get(info.param), // + std::get(info.param)); }); } // namespace kai::test