diff --git a/CHANGELOG.md b/CHANGELOG.md index 6bfa81b6ff61c13e94062a73b5af544fe2c490f4..fb3e8f441e95ac705320bd54bb2c403672946787 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Optimizations for FEAT_DotProd. - New 1x8 block size variant of matrix multiplication of QAI8DXP LHS and QSI4C32P RHS with F32 output. - Optimizations for FEAT_DotProd. +- New SME2 micro-kernels: + - Matrix multiplication (1xN) of QAI8 LHS and QSI8 RHS to produce QAI8 output. - Added demonstration of integration using CMake in F16 Arm® Neon™ matrix multiplication example. - Fixes: - Fix the RHS packing micro-kernel kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to handle null bias. diff --git a/CMakeLists.txt b/CMakeLists.txt index 19a7c8635c6bef3240a87d8c3d7f1160f8a07c5f..037bd611dca3a807f4bb2943d5c11195497398fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -194,6 +194,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 54b2b78574bf13a7effcaa6dcaa9b6eaeaff1259..4fa64cc9043964ca14b20a106d1bf0ac98d4a414 100644 --- 
a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -144,6 +144,7 @@ SME2_KERNELS = [ "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", + "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c new file mode 100644 index 0000000000000000000000000000000000000000..288c207f544cd72af446f144ca8a2fc6425daa4a --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c @@ -0,0 +1,960 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. 
+
+#include "kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h"
+
+#include <stddef.h>  // size_t
+#include <stdint.h>  // int8_t, int32_t, uint64_t
+
+#include "kai/kai_common.h"
+
+static const size_t kai_m_step = 1;   // Returned as-is by the m_step getter.
+static const size_t kai_nr = 2;       // Multiplier in the nr getter (scaled by SME vector length / kai_kr).
+static const size_t kai_n_step = 16;  // Multiplier in the n_step getter (scaled the same way as kai_nr).
+static const size_t kai_kr = 4;       // Returned by the kr getter; also the K rounding unit of the RHS stride.
+static const size_t kai_sr = 1;       // Returned as-is by the sr getter.
+
+size_t kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
+    return kai_m_step;  // Constant M step (1).
+}
+
+size_t kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
+    return kai_n_step * kai_get_sme_vector_length_u8() / kai_kr;  // 16 * u8 SME vector length / 4.
+}
+
+size_t kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
+    return kai_nr * kai_get_sme_vector_length_u8() / kai_kr;  // 2 * u8 SME vector length / 4.
+}
+
+size_t kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
+    return kai_kr;  // Constant kr (4).
+}
+
+size_t kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
+    return kai_sr;  // Constant sr (1).
+}
+
+size_t kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m_idx, size_t k) {
+    KAI_ASSUME(m_idx == 0);  // Only the first LHS row may be addressed.
+
+    return m_idx * k;  // Row pitch is k, so the result is 0 whenever the assumption holds.
+}
+
+static size_t kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t k) {
+    return kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() *  // n_step columns of:
+        (kai_roundup(k, kai_kr) * sizeof(int8_t) + sizeof(int32_t) + sizeof(int32_t));  // padded K + two 4-byte fields
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t n_idx, size_t k) {
+    KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() == 0);
+    // n_idx must start an n_step block; the offset is that block's index times the packed block stride.
+    const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot();
+    return block_idx * kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(k);
+}
+
+size_t kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(
+    size_t m_idx, size_t n_idx, size_t dst_stride) {
+    KAI_ASSUME(m_idx == 0);  // Only the first output row may be addressed.
+    KAI_ASSUME(n_idx %
kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() == 0); + + return (m_idx * dst_stride) + (n_idx * sizeof(int8_t)); +} + +size_t kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m, size_t n) { + return m * n * sizeof(int8_t); +} + +void kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params) { + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(lhs_stride); + KAI_ASSUME(m == 1); + + typedef struct { + int32_t c_offset; + int32_t maxval; + int32_t minval; + } KernelArgs; + + KernelArgs k_args; + k_args.maxval = params->max_value; + k_args.minval = params->min_value; + k_args.c_offset = params->output_zero_point; + + size_t N = n; + size_t K = k; + + const void* A_ptr = lhs; + const void* B_ptr = rhs_packed; + void* output_ptr = dst; + + uint64_t flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x8, #0x0\n" + "mov x16, %x[B_ptr]\n" + "cntw x15, ALL, MUL #4\n" + "mov x14, %x[output_ptr]\n" + "add x13, %x[N], x15\n" + "ptrue p2.b\n" + "sub x13, x13, #0x1\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "udiv x13, x13, x15\n" + "mov x22, #0x1\n" + "add x21, x13, #0x3\n" + "and x21, x21, #0xfffffffffffffffc\n" + "mul x21, x21, x15\n" + "mul x21, x21, %x[K]\n" + "1:" // RHS size check loop + "cmp x21, #0x200000\n" + "blt 2f\n" + "tbnz x21, #0, 3f\n" + "lsr x21, x21, #0x1\n" + "lsl x22, x22, #0x1\n" + "b 1b\n" + "2:" // RHS do prefetch + "lsl x20, x21, #0x26\n" + "sub x22, x22, #0x1\n" + "lsl x22, x22, #0x16\n" + "orr x21, x21, x20\n" + "orr x21, x21, x22\n" + ".inst 0xf8b54a1a // rprfm pldonce, x21, [x16]\n" + "3:" // RHS prefetch exit + "add x12, %x[K], #0x3\n" + "cntw x20, ALL, MUL #2\n" + "mov z25.s, #0x0\n" + "mov z27.b, #0x1\n" + "bic x12, x12, #0x3\n" + "bic 
%x[flags], %x[flags], #0x80000000\n" + "add x12, x12, #0x8\n" + "mul x12, x12, x20\n" + "4:" // Column loop + "cmp x13, #0x4\n" + "bge 25f\n" + "cmp x13, #0x2\n" + "bgt 18f\n" + "beq 11f\n" + "cntw x20, ALL, MUL #2\n" + "add x23, x16, x12\n" + ".inst 0xa0404210 // ld1w { z16.s-z17.s }, pn8.b/Z, [x16]\n" + "cmp %x[N], x20\n" + "mov x11, %x[K]\n" + "csel x23, x23, x16, GT\n" + "mov x21, %x[N]\n" + ".inst 0xa04042f2 // ld1w { z18.s-z19.s }, pn8.b/Z, [x23]\n" + "mov x10, %x[A_ptr]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + "addvl x16, x16, #2\n" + "addvl x23, x23, #2\n" + ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" + "ble 7f\n" + "5:" // Width 1: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d93a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[0]\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d96a0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1]\n" + ".inst 0xc15d9a20 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[2]\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 6f\n" + "sdot z25.s, z13.b, z27.b\n" + "6:" // Width 1: Multiply loop: unique 1: skip row sum + "sub 
x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 5b\n" + "7:" // Width 1: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002e7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d90a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0]\n" + "ble 8f\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d97a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1]\n" + "ble 8f\n" + ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9920 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2]\n" + "ble 8f\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + "8:" // Width 1: Multiply loop: multiply skip + "tbnz %x[flags], #31, 9f\n" + "9:" // Width 1: Multiply loop: unique 2: skip row sum + ".inst 0xc0060c08 // mova { z8.d-z11.d }, za.d[x8, #0]\n" + ".inst 0xa040421e // ld1w { z30.s-z31.s }, pn8.b/Z, [x16]\n" + "add x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa04042f8 // ld1w { z24.s-z25.s }, pn8.b/Z, [x23]\n" + "add x20, %x[k_args], %[maxval]\n" + "ld1rw { z2.s }, p2/Z, [x22]\n" + "ld1rw { z13.s }, p2/Z, [x21]\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + "ld1rw { z20.s }, p2/Z, [x20]\n" + "fmul z8.s, z8.s, z30.s\n" + "fmul z9.s, z9.s, z31.s\n" + "fmul z10.s, z10.s, z24.s\n" + "fmul z11.s, z11.s, 
z25.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1a2ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z2.s\n" + ".inst 0xc1b4cda8 // sclamp { z8.s-z11.s }, z13.s, z20.s\n" + "uzp1 z8.h, z8.h, z9.h\n" + "uzp1 z0.h, z10.h, z11.h\n" + "uzp1 z8.b, z8.b, z0.b\n" + "st1b { z8.b }, p1, [x14]\n" + "b 32f\n" + "11:" // Width 2 + "add x24, x16, x12, LSL #1\n" + "cntw x20, ALL, MUL #6\n" + ".inst 0xa0404214 // ld1w { z20.s-z21.s }, pn8.b/Z, [x16]\n" + "add x22, x24, x12\n" + "cmp %x[N], x20\n" + ".inst 0xa040430c // ld1w { z12.s-z13.s }, pn8.b/Z, [x24]\n" + "add x23, x16, x12\n" + "csel x22, x22, x16, GT\n" + ".inst 0xa04042f6 // ld1w { z22.s-z23.s }, pn8.b/Z, [x23]\n" + "mov x11, %x[K]\n" + "sub x21, %x[N], x15\n" + ".inst 0xa04042ce // ld1w { z14.s-z15.s }, pn8.b/Z, [x22]\n" + "mov x10, %x[A_ptr]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xc0040e80 // mova za.d[x8, #0], { z20.d-z23.d }\n" + "addvl x23, x23, #2\n" + "addvl x24, x24, #2\n" + ".inst 0xc0040d81 // mova za.d[x8, #1], { z12.d-z15.d }\n" + "addvl x22, x22, #2\n" + "ble 14f\n" + "12:" // Width 2: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400305 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" + ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" 
+ ".inst 0xc15d90a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[0]\n" + ".inst 0xa0400311 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9520 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[1]\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9621 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1]\n" + ".inst 0xa0400305 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9820 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2]\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d98a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xa0400309 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + ".inst 0xc15d9d21 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 13f\n" + "sdot z25.s, z13.b, z27.b\n" + "13:" // Width 2: Multiply loop: unique 3: skip row sum + "sub x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 12b\n" + "14:" // Width 2: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400301 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002c3 // ldnt1b 
{ z2.b-z3.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" + ".inst 0xc15d9021 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[0]\n" + "ble 15f\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa040031d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d96a0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1]\n" + ".inst 0xc15d97a1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[1]\n" + "ble 15f\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400319 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002db // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9820 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2]\n" + ".inst 0xc15d9b21 // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[2]\n" + "ble 15f\n" + ".inst 0xa0400219 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002fb // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa040031d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24]\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + ".inst 0xc15d9f20 // sdot za.s[x8, 0], { z24.b-z27.b }, z13.b[3]\n" + ".inst 0xc15d9fa1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3]\n" + "15:" // Width 2: Multiply loop: multiply skip + "tbnz %x[flags], #31, 16f\n" + "16:" // Width 2: Multiply loop: unique 4: skip row sum + ".inst 0xc0060c00 // mova { z0.d-z3.d }, za.d[x8, #0]\n" + ".inst 0xa0404208 // ld1w { z8.s-z9.s }, pn8.b/Z, [x16]\n" + "add 
x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa04042fe // ld1w { z30.s-z31.s }, pn8.b/Z, [x23]\n" + "add x20, %x[k_args], %[maxval]\n" + ".inst 0xc0060c24 // mova { z4.d-z7.d }, za.d[x8, #1]\n" + "add x16, x16, x12, LSL #1\n" + "ld1rw { z14.s }, p2/Z, [x22]\n" + "add x23, x23, x12, LSL #1\n" + "ld1rw { z11.s }, p2/Z, [x21]\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + "ld1rw { z10.s }, p2/Z, [x20]\n" + "fmul z0.s, z0.s, z8.s\n" + "fmul z1.s, z1.s, z9.s\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + "fmul z2.s, z2.s, z30.s\n" + "fmul z3.s, z3.s, z31.s\n" + ".inst 0xc1b8e000 // frintn { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc131e000 // fcvtzs { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc1aeab00 // add { z0.s-z3.s }, { z0.s-z3.s }, z14.s\n" + ".inst 0xc1aacd60 // sclamp { z0.s-z3.s }, z11.s, z10.s\n" + "uzp1 z0.h, z0.h, z1.h\n" + "uzp1 z16.h, z2.h, z3.h\n" + "uzp1 z0.b, z0.b, z16.b\n" + "st1b { z0.b }, p2, [x14]\n" + ".inst 0xa1404217 // ld1w { z23.s, z31.s }, pn8.b/Z, [x16]\n" + ".inst 0xa14042f6 // ld1w { z22.s, z30.s }, pn8.b/Z, [x23]\n" + "fmul z4.s, z4.s, z23.s\n" + "fmul z5.s, z5.s, z31.s\n" + "fmul z6.s, z6.s, z22.s\n" + "fmul z7.s, z7.s, z30.s\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1aeab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z14.s\n" + ".inst 0xc1aacd64 // sclamp { z4.s-z7.s }, z11.s, z10.s\n" + "uzp1 z4.h, z4.h, z5.h\n" + "uzp1 z2.h, z6.h, z7.h\n" + "uzp1 z4.b, z4.b, z2.b\n" + "st1b { z4.b }, p1, [x14, #1, MUL VL]\n" + "b 32f\n" + "18:" // Width 3 + "add x26, x16, x12, LSL #2\n" + "cntw x20, ALL, MUL #10\n" + ".inst 0xa0404210 // ld1w { z16.s-z17.s }, pn8.b/Z, [x16]\n" + "add x25, x16, x12, LSL #1\n" + "add x24, x26, x12\n" + ".inst 0xa040435c // ld1w { z28.s-z29.s }, pn8.b/Z, [x26]\n" + "cmp %x[N], x20\n" + "add x23, x16, x12\n" + ".inst 0xa040432c // ld1w { z12.s-z13.s }, pn8.b/Z, 
[x25]\n" + "add x22, x25, x12\n" + "csel x24, x24, x16, GT\n" + ".inst 0xa04042f2 // ld1w { z18.s-z19.s }, pn8.b/Z, [x23]\n" + "mov x20, #0x2\n" + ".inst 0xa04042ce // ld1w { z14.s-z15.s }, pn8.b/Z, [x22]\n" + "mov x11, %x[K]\n" + ".inst 0xa040431e // ld1w { z30.s-z31.s }, pn8.b/Z, [x24]\n" + "msub x21, x15, x20, %x[N]\n" + "mov x10, %x[A_ptr]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + ".inst 0xc0040d81 // mova za.d[x8, #1], { z12.d-z15.d }\n" + "addvl x16, x16, #2\n" + "addvl x23, x23, #2\n" + ".inst 0xc0040f82 // mova za.d[x8, #2], { z28.d-z31.d }\n" + "addvl x25, x25, #2\n" + "addvl x22, x22, #2\n" + "addvl x26, x26, #2\n" + "addvl x24, x24, #2\n" + "ble 21f\n" + "19:" // Width 3: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400329 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400351 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9020 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9121 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0]\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9222 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[0]\n" + ".inst 0xa0400321 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002c3 // ldnt1b 
{ z2.b-z3.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d97a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9421 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[1]\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d94a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1]\n" + ".inst 0xa040033d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9ba1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2]\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xa0400335 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002d7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400351 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9ea1 // sdot za.s[x8, 1], { z20.b-z23.b }, z13.b[3]\n" + ".inst 0xc15d9e22 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 20f\n" + "sdot z25.s, z13.b, z27.b\n" + 
"20:" // Width 3: Multiply loop: unique 5: skip row sum + "sub x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 19b\n" + "21:" // Width 3: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400325 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa040035d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" + "addvl x26, x26, #2\n" + ".inst 0xa040031f // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d90a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[0]\n" + ".inst 0xc15d93a2 // sdot za.s[x8, 2], { z28.b-z31.b }, z13.b[0]\n" + "ble 22f\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400329 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9620 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[1]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9521 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[1]\n" + ".inst 0xc15d94a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1]\n" + "ble 22f\n" + ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" + "addvl x23, 
x23, #2\n" + ".inst 0xa040033d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400359 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9920 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2]\n" + "addvl x26, x26, #2\n" + ".inst 0xa040031b // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9ba1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2]\n" + ".inst 0xc15d9b22 // sdot za.s[x8, 2], { z24.b-z27.b }, z13.b[2]\n" + "ble 22f\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400339 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x25]\n" + ".inst 0xa04002db // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x22]\n" + ".inst 0xa0400355 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + ".inst 0xa0400317 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x24]\n" + ".inst 0xc15d9f21 // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[3]\n" + ".inst 0xc15d9ea2 // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3]\n" + "22:" // Width 3: Multiply loop: multiply skip + "tbnz %x[flags], #31, 23f\n" + "23:" // Width 3: Multiply loop: unique 6: skip row sum + ".inst 0xc0060c18 // mova { z24.d-z27.d }, za.d[x8, #0]\n" + ".inst 0xa0404202 // ld1w { z2.s-z3.s }, pn8.b/Z, [x16]\n" + "add x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa04042e6 // ld1w { z6.s-z7.s }, pn8.b/Z, [x23]\n" + "add x20, %x[k_args], %[maxval]\n" + ".inst 0xc0060c28 // mova { z8.d-z11.d }, za.d[x8, #1]\n" + "add x16, x16, x12, LSL #1\n" + "ld1rw { z0.s }, p2/Z, [x22]\n" + "add x23, x23, x12, LSL #1\n" + ".inst 0xc0060c5c // mova { z28.d-z31.d }, za.d[x8, #2]\n" + "ld1rw { z19.s }, p2/Z, [x21]\n" + ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n" + 
"ld1rw { z18.s }, p2/Z, [x20]\n" + "fmul z24.s, z24.s, z2.s\n" + "fmul z25.s, z25.s, z3.s\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + "fmul z26.s, z26.s, z6.s\n" + "fmul z27.s, z27.s, z7.s\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1b8e318 // frintn { z24.s-z27.s }, { z24.s-z27.s }\n" + ".inst 0xc131e318 // fcvtzs { z24.s-z27.s }, { z24.s-z27.s }\n" + ".inst 0xc1a0ab18 // add { z24.s-z27.s }, { z24.s-z27.s }, z0.s\n" + ".inst 0xc1b2ce78 // sclamp { z24.s-z27.s }, z19.s, z18.s\n" + "uzp1 z24.h, z24.h, z25.h\n" + "uzp1 z16.h, z26.h, z27.h\n" + "uzp1 z24.b, z24.b, z16.b\n" + "st1b { z24.b }, p2, [x14]\n" + ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" + "add x16, x16, x12, LSL #1\n" + ".inst 0xa14042f1 // ld1w { z17.s, z25.s }, pn8.b/Z, [x23]\n" + "add x23, x23, x12, LSL #1\n" + "fmul z8.s, z8.s, z7.s\n" + "fmul z9.s, z9.s, z15.s\n" + "fmul z10.s, z10.s, z17.s\n" + "fmul z11.s, z11.s, z25.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" + ".inst 0xc1b2ce68 // sclamp { z8.s-z11.s }, z19.s, z18.s\n" + "uzp1 z8.h, z8.h, z9.h\n" + "uzp1 z16.h, z10.h, z11.h\n" + "uzp1 z8.b, z8.b, z16.b\n" + "st1b { z8.b }, p2, [x14, #1, MUL VL]\n" + ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" + ".inst 0xa14042f1 // ld1w { z17.s, z25.s }, pn8.b/Z, [x23]\n" + "fmul z28.s, z28.s, z7.s\n" + "fmul z29.s, z29.s, z15.s\n" + "fmul z30.s, z30.s, z17.s\n" + "fmul z31.s, z31.s, z25.s\n" + ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" + ".inst 0xc1b2ce7c // sclamp { z28.s-z31.s }, z19.s, z18.s\n" + "uzp1 z28.h, z28.h, z29.h\n" + "uzp1 z16.h, z30.h, z31.h\n" + "uzp1 z28.b, z28.b, z16.b\n" + "st1b { z28.b }, p1, 
[x14, #2, MUL VL]\n" + "b 32f\n" + "25:" // Width 4 + "add x9, x16, x12, LSL #2\n" + "cntw x20, ALL, MUL #14\n" + ".inst 0xa040420c // ld1w { z12.s-z13.s }, pn8.b/Z, [x16]\n" + "add x28, x9, x12, LSL #1\n" + "add x27, x16, x12, LSL #1\n" + ".inst 0xa0404124 // ld1w { z4.s-z5.s }, pn8.b/Z, [x9]\n" + "add x26, x28, x12\n" + "cmp %x[N], x20\n" + ".inst 0xa0404368 // ld1w { z8.s-z9.s }, pn8.b/Z, [x27]\n" + "add x25, x16, x12\n" + "add x24, x27, x12\n" + ".inst 0xa0404380 // ld1w { z0.s-z1.s }, pn8.b/Z, [x28]\n" + "add x22, x9, x12\n" + "csel x26, x26, x16, GT\n" + ".inst 0xa040432e // ld1w { z14.s-z15.s }, pn8.b/Z, [x25]\n" + "mov x20, #0x3\n" + ".inst 0xa040430a // ld1w { z10.s-z11.s }, pn8.b/Z, [x24]\n" + "mov x11, %x[K]\n" + ".inst 0xa04042c6 // ld1w { z6.s-z7.s }, pn8.b/Z, [x22]\n" + "msub x21, x15, x20, %x[N]\n" + "mov x10, %x[A_ptr]\n" + ".inst 0xa0404342 // ld1w { z2.s-z3.s }, pn8.b/Z, [x26]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + ".inst 0xc0040d80 // mova za.d[x8, #0], { z12.d-z15.d }\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + ".inst 0xc0040d01 // mova za.d[x8, #1], { z8.d-z11.d }\n" + "add x23, x16, x12, LSL #3\n" + "addvl x16, x16, #2\n" + ".inst 0xc0040c82 // mova za.d[x8, #2], { z4.d-z7.d }\n" + "addvl x25, x25, #2\n" + "addvl x27, x27, #2\n" + ".inst 0xc0040c03 // mova za.d[x8, #3], { z0.d-z3.d }\n" + "addvl x24, x24, #2\n" + "addvl x9, x9, #2\n" + "addvl x22, x22, #2\n" + "addvl x28, x28, #2\n" + "addvl x26, x26, #2\n" + "ble 28f\n" + "26:" // Width 4: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa0400323 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400369 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa040030b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24]\n" + "addvl 
x24, x24, #2\n" + ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9020 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400395 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9121 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400357 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d90a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[0]\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400323 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xc15d92a3 // sdot za.s[x8, 3], { z20.b-z23.b }, z13.b[0]\n" + ".inst 0xa0400365 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400131 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9420 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[1]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400381 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d94a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[1]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400343 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d9622 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[1]\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400337 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xc15d9423 // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[1]\n" + ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 
0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400381 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9a21 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[2]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400343 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xc15d9823 // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[2]\n" + ".inst 0xa040037d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa040031f // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400135 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9ca0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002d7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9fa1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d9ea2 // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3]\n" + ".inst 0xc15d9ca3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 27f\n" + "sdot z25.s, z13.b, z27.b\n" + "27:" // Width 4: Multiply loop: unique 7: skip row sum + "sub x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 26b\n" + "28:" // Width 4: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + 
"addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa040013d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d90a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400389 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9221 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[0]\n" + "addvl x28, x28, #2\n" + ".inst 0xa040034b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d93a2 // sdot za.s[x8, 2], { z28.b-z31.b }, z13.b[0]\n" + ".inst 0xc15d9123 // sdot za.s[x8, 3], { z8.b-z11.b }, z13.b[0]\n" + "ble 29f\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400121 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d94a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[1]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9621 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d9422 // sdot za.s[x8, 2], { z0.b-z3.b }, z13.b[1]\n" + ".inst 0xc15d94a3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[1]\n" + "ble 29f\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b 
}, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400337 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400369 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa040030b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400391 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9921 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[2]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400353 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xc15d9a23 // sdot za.s[x8, 3], { z16.b-z19.b }, z13.b[2]\n" + "ble 29f\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400361 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x27]\n" + ".inst 0xa0400303 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x24]\n" + ".inst 0xa0400131 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9ca0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3]\n" + ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" + ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9c21 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[3]\n" + ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9e22 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3]\n" + ".inst 0xc15d9ca3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3]\n" + "29:" // Width 4: Multiply loop: multiply skip + "tbnz %x[flags], #31, 30f\n" + "sdot z25.s, z13.b, z27.b\n" + "30:" // Width 4: Multiply loop: unique 8: skip row sum + ".inst 
0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" + ".inst 0xa0404202 // ld1w { z2.s-z3.s }, pn8.b/Z, [x16]\n" + "add x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa040432c // ld1w { z12.s-z13.s }, pn8.b/Z, [x25]\n" + "add x20, %x[k_args], %[maxval]\n" + ".inst 0xc0060c3c // mova { z28.d-z31.d }, za.d[x8, #1]\n" + "add x16, x16, x12, LSL #1\n" + "ld1rw { z0.s }, p2/Z, [x22]\n" + "add x25, x25, x12, LSL #1\n" + ".inst 0xc0060c54 // mova { z20.d-z23.d }, za.d[x8, #2]\n" + "ld1rw { z1.s }, p2/Z, [x21]\n" + ".inst 0xc0060c68 // mova { z8.d-z11.d }, za.d[x8, #3]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + "ld1rw { z17.s }, p2/Z, [x20]\n" + "fmul z4.s, z4.s, z2.s\n" + "fmul z5.s, z5.s, z3.s\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmul z6.s, z6.s, z12.s\n" + "fmul z7.s, z7.s, z13.s\n" + ".inst 0xc132e294 // scvtf { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc1b1cc24 // sclamp { z4.s-z7.s }, z1.s, z17.s\n" + "uzp1 z4.h, z4.h, z5.h\n" + "uzp1 z16.h, z6.h, z7.h\n" + "uzp1 z4.b, z4.b, z16.b\n" + "st1b { z4.b }, p2, [x14]\n" + ".inst 0xa1404212 // ld1w { z18.s, z26.s }, pn8.b/Z, [x16]\n" + "add x16, x16, x12, LSL #1\n" + ".inst 0xa0404324 // ld1w { z4.s-z5.s }, pn8.b/Z, [x25]\n" + "add x25, x25, x12, LSL #1\n" + "fmul z28.s, z28.s, z18.s\n" + "fmul z29.s, z29.s, z26.s\n" + "fmul z30.s, z30.s, z4.s\n" + "fmul z31.s, z31.s, z5.s\n" + ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" + ".inst 0xc1b1cc3c // sclamp { z28.s-z31.s }, z1.s, z17.s\n" + "uzp1 z28.h, z28.h, z29.h\n" + "uzp1 z16.h, z30.h, 
z31.h\n" + "uzp1 z28.b, z28.b, z16.b\n" + "st1b { z28.b }, p2, [x14, #1, MUL VL]\n" + ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" + "add x16, x16, x12, LSL #1\n" + ".inst 0xa1404324 // ld1w { z4.s, z12.s }, pn8.b/Z, [x25]\n" + "add x25, x25, x12, LSL #1\n" + "fmul z20.s, z20.s, z7.s\n" + "fmul z21.s, z21.s, z15.s\n" + "fmul z22.s, z22.s, z4.s\n" + "fmul z23.s, z23.s, z12.s\n" + ".inst 0xc1b8e294 // frintn { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc131e294 // fcvtzs { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc1a0ab14 // add { z20.s-z23.s }, { z20.s-z23.s }, z0.s\n" + ".inst 0xc1b1cc34 // sclamp { z20.s-z23.s }, z1.s, z17.s\n" + "uzp1 z20.h, z20.h, z21.h\n" + "uzp1 z16.h, z22.h, z23.h\n" + "uzp1 z20.b, z20.b, z16.b\n" + "st1b { z20.b }, p2, [x14, #2, MUL VL]\n" + ".inst 0xa1404206 // ld1w { z6.s, z14.s }, pn8.b/Z, [x16]\n" + ".inst 0xa1404327 // ld1w { z7.s, z15.s }, pn8.b/Z, [x25]\n" + "fmul z8.s, z8.s, z6.s\n" + "fmul z9.s, z9.s, z14.s\n" + "fmul z10.s, z10.s, z7.s\n" + "fmul z11.s, z11.s, z15.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" + ".inst 0xc1b1cc28 // sclamp { z8.s-z11.s }, z1.s, z17.s\n" + "uzp1 z8.h, z8.h, z9.h\n" + "uzp1 z16.h, z10.h, z11.h\n" + "uzp1 z8.b, z8.b, z16.b\n" + "st1b { z8.b }, p1, [x14, #3, MUL VL]\n" + "addvl x14, x14, #4\n" + "subs x13, x13, #0x4\n" + "mov x16, x23\n" + "sub %x[N], %x[N], x15, LSL #2\n" + "bgt 4b\n" + "32:" // Exit + ".inst 0xd503467f // SMSTOP\n" + : [N] "+&r"(N), [flags] "+&r"(flags) + : [A_ptr] "r"(A_ptr), [B_ptr] "r"(B_ptr), [K] "r"(K), [c_offset] "I"(offsetof(KernelArgs, c_offset)), + [k_args] "r"(&k_args), [maxval] "I"(offsetof(KernelArgs, maxval)), [minval] "I"(offsetof(KernelArgs, minval)), + [output_ptr] "r"(output_ptr) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", 
"p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", + "x27", "x28", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", + "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", + "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h new file mode 100644 index 0000000000000000000000000000000000000000..cb2932f9c44175773cd2b66a2ac2476b32a6fe25 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h @@ -0,0 +1,116 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// -# kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme + +/// -------------------------------------------------- + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets kr value. 
+/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// +/// @param[in] m_idx Row index. This must be 0. +/// @param[in] k Columns of unpacked LHS. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of n_step +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be 0 +/// @param[in] n_idx Column index. Must be multiple of n_step +/// @param[in] dst_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. 
+/// +/// The pointer to each buffer (LHS, packed RHS and output) must be advanced by the offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot. +/// +/// @param[in] m Number of output rows to be computed. This must be 1. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operands. +/// @param[in] lhs LHS matrix buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. Unused parameter. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. Currently, an unused parameter. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Currently, an unused parameter. 
+/// @param[in] params Quantization parameters +void kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..9b59dd6c9d5cde9f451a4efd515558c1e14d88f8 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h @@ -0,0 +1,50 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_qai8_qai8_qsi8cxp + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_lhs_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t 
(*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_qai8_qai8_qsi8cxp_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params); + +/// Micro-kernel interface +struct kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel { + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_m_step_func_t get_m_step; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_n_step_func_t get_n_step; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_nr_func_t get_nr; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_kr_func_t get_kr; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_sr_func_t get_sr; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_lhs_offset_func_t get_lhs_offset; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_qai8_qai8_qsi8cxp_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 81ffaf4d555e7c5396f7409c313ed6e0b931d4a4..252dd2fbe76ba4447c19cd7fc00a73f4ac5c394f 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -13,7 +13,8 @@ #include #include -#include "kai/kai_common.h" +#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" +#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" #include 
"kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" @@ -91,7 +92,8 @@ const static RhsPackKernel rhs_pack = { struct MatMulVariant { std::string_view name; ///< Test identification - MatMulShape acc; ///< Accumulator shape + MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) + MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) std::function is_supported; ///< HW support check @@ -103,7 +105,12 @@ struct MatMulVariant { const std::array gemm_variants = { MatMulVariant{ .name = "matmul_qai8_qai8p_qsi8cxp", - .acc{ + .acc_pack{ + .m = 2 * get_sme_vector_length(), + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + .acc_step{ .m = 2 * get_sme_vector_length(), .n = 2 * get_sme_vector_length(), .k = sizeof(int32_t) / sizeof(int8_t), @@ -139,6 +146,45 @@ const std::array gemm_variants = { }, }; +const std::array gemv_variants = { + MatMulVariant{ + .name = "matmul_qai8_qai8_qsi8cxp", + .acc_pack{ + .m = 1, + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + .acc_step{ + .m = 1, + .n = 16 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + + .is_supported = cpu_has_sme2, + + .lhs_pack = std::nullopt, + .rhs_pack = rhs_pack, + .matmul = MatMulKernel{ + .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_mr = []() -> size_t { return 1; }, + .get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_packed_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_packed_rhs_offset = 
kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .matmul = + [](size_t m, size_t n, size_t k, const void* lhs, const void* rhs, void* dst, size_t dst_stride_row, + size_t dst_stride_col, const kai_matmul_requantize32_params* quant_param) { + kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + m, n, k, lhs, sizeof(int8_t), rhs, dst, dst_stride_row, dst_stride_col, quant_param); + }, + }, + }, +}; + constexpr uint64_t seed = 0; ///< Random seed used for tests constexpr float output_clamp_rate = 0.1F; ///< Clamping range in ration of output @@ -181,6 +227,21 @@ struct TestReference { Buffer packed_rhs; }; +/// Make sure that interface matches +static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface + [[maybe_unused]] = { + .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, +}; + /// Generate test reference data static TestReference get_test_reference(const MatMulShape& shape, 
const MatMulVariant& variant) { // ============================================================ @@ -196,27 +257,39 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa // * LHS: 8-bit asymmetric per-matrix quantization. // * RHS: 8-bit symmetric per-channel quantization. // * Bias: 32-bit symmetric per-channel quantization. + // + // Treat entire LHS as one row vector, to calculate one single pair of scale and zero point auto [lhs_qai8, lhs_qai8_scales, lhs_qai8_zero_points] = quantize_asymmetric_per_block_dynamic( lhs_f32.data(), 1, shape.m * shape.k, shape.m * shape.k); const auto lhs_scale = read_array(lhs_qai8_scales.data(), 0); const auto lhs_zero_point = read_array(lhs_qai8_zero_points.data(), 0); + // Transpose, then quantize symmetrically, then transpose back. This will give one + // quantization value for each column const auto rhs_f32_t = transpose(rhs_f32.data(), shape.k, shape.n); auto [rhs_qsi8_t, rhs_scales] = quantize_symmetric_per_block_dynamic(rhs_f32_t.data(), shape.n, shape.k, shape.k); auto rhs_qsi8 = transpose(rhs_qsi8_t.data(), shape.n, shape.k); + // Multiply each RHS scale with the LHS scale to get the bias scales const auto bias_scales = mul(&lhs_scale, 1, 1, rhs_scales.data(), 1, shape.n); + // Calculate quantized bias values, by treating bias as a column, and + // scale using the bias scales. This will scale each bias value individually auto bias_qsi32 = quantize_symmetric_per_block(bias_f32.data(), bias_scales.data(), shape.n, 1, 1); // Runs the reference implementation of matmul to produce floating-point result. 
const auto ref_dst_f32 = matmul_nt_t_quantized( - shape.m, shape.n, shape.k, lhs_qai8.data(), &lhs_scale, &lhs_zero_point, shape.m, shape.k, - rhs_qsi8_t.data(), rhs_scales.data(), nullptr, 1, shape.k, bias_qsi32.data(), bias_scales.data(), nullptr, - 1); + shape.m, shape.n, shape.k, // matmul shape + lhs_qai8.data(), &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point + shape.m, shape.k, // LHS quantization window shape + rhs_qsi8_t.data(), rhs_scales.data(), nullptr, // RHS scaling factors + 1, shape.k, // RHS quantization window shape + bias_qsi32.data(), bias_scales.data(), nullptr, // Bias, scaling and zero points + 1 // Bias quantization window shape + ); // Computes the output quantization information and clamping limits. // @@ -247,17 +320,20 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa const auto ref_dst_f32_clamped = clamp(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max); auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block( - ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, 1, shape.m * shape.n, shape.m * shape.n); + ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, // values, scales, zero point + 1, shape.m * shape.n, // data shape + shape.m * shape.n // quantization window width + ); // Runs the reference implementation of the packing functions. // // The reference packing functions cannot be executed earlier // because we need the reference floating-point output first to have // the quantization information. 
- auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, variant.acc.m, variant.acc.k); + auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k); auto packed_rhs = matmul_pack_rhs_nxk_static_quantized( rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, - variant.acc.n, variant.acc.k); + variant.acc_pack.n, variant.acc_pack.k); return { .clamp = {.min = dst_qai8_clamp_min, .max = dst_qai8_clamp_max}, @@ -287,20 +363,22 @@ static void test_lhs_pack( KAI_ASSUME(variant.lhs_pack.has_value()); const auto imp_packed_lhs_size = - variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc.m, variant.acc.k, 1); + variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); ASSERT_EQ(imp_packed_lhs_size, reference.packed_lhs.size()); Buffer imp_packed_lhs(imp_packed_lhs_size); const auto imp_lhs_offset = variant.lhs_pack->get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t)); - const auto imp_packed_lhs_offset = - variant.lhs_pack->get_packed_lhs_offset(output_area.start_row(), shape.k, variant.acc.m, variant.acc.k, 1); + const auto imp_packed_lhs_offset = variant.lhs_pack->get_packed_lhs_offset( + output_area.start_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); variant.lhs_pack->pack( - output_area.height(), shape.k, variant.acc.m, variant.acc.k, 1, 0, reference.lhs_qai8.data() + imp_lhs_offset, - shape.k * sizeof(int8_t), imp_packed_lhs.data() + imp_packed_lhs_offset); + output_area.height(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1, 0, + reference.lhs_qai8.data() + imp_lhs_offset, shape.k * sizeof(int8_t), + imp_packed_lhs.data() + imp_packed_lhs_offset); const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m - ? variant.lhs_pack->get_packed_lhs_offset(output_area.end_row(), shape.k, variant.acc.m, variant.acc.k, 1) + ? 
variant.lhs_pack->get_packed_lhs_offset( + output_area.end_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1) : imp_packed_lhs_size; for (size_t i = 0; i < reference.packed_lhs.size(); ++i) { @@ -330,7 +408,7 @@ static void test_rhs_pack( }; variant.rhs_pack.pack( - 1, output_area.width(), shape.k, variant.acc.n, variant.acc.k, 1, shape.n * sizeof(int8_t), + 1, output_area.width(), shape.k, variant.acc_pack.n, variant.acc_pack.k, 1, shape.n * sizeof(int8_t), reference.rhs_qsi8.data() + imp_rhs_offset, reference.bias_qsi32.data() + imp_bias_offset, reference.rhs_scales.data() + imp_scale_offset, imp_packed_rhs.data() + imp_packed_rhs_offset, 0, &imp_pack_rhs_params); @@ -339,16 +417,22 @@ static void test_rhs_pack( ? variant.rhs_pack.get_packed_rhs_offset(output_area.end_col(), shape.k) : imp_packed_rhs_size; + size_t mismatches = 0; for (size_t i = 0; i < reference.packed_rhs.size(); ++i) { if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) { - ASSERT_EQ(imp_packed_rhs[i], reference.packed_rhs[i]); + if (imp_packed_rhs[i] != reference.packed_rhs[i]) { + mismatches += 1; + } } else { - ASSERT_EQ(imp_packed_rhs[i], 0); + if (imp_packed_rhs[i] != 0) { + mismatches += 1; + } } } + ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing"; } -/// Test MatMul of GEMM like kernel +/// Test MatMul of GEMM/GEMV like kernel static void test_matmul( const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n); @@ -377,6 +461,7 @@ static void test_matmul( reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t), sizeof(int8_t), &imp_main_params); + size_t mismatches = 0; for (size_t y = 0; y < shape.m; ++y) { for (size_t x = 0; x < shape.n; ++x) { const auto i = y * shape.n + x; @@ -388,11 +473,10 @@ static void test_matmul( const auto error = 
std::abs(imp_value - ref_value); const auto threshold = in_area ? 1 : 0; - if (error > threshold) { - ASSERT_EQ(imp_value, ref_value); - } + mismatches += static_cast(error > threshold); } } + ASSERT_EQ(mismatches, 0) << "There are mismatched between reference result actual result"; } using ThisTest = testing::TestWithParam>; @@ -426,24 +510,26 @@ TEST_P(ThisTest, EndToEnd) { const auto imp_kr = variant.matmul.get_kr(); const auto imp_sr = variant.matmul.get_sr(); - ASSERT_EQ(imp_mr, variant.acc.m); - ASSERT_EQ(imp_nr, variant.acc.n); - ASSERT_EQ(imp_kr, variant.acc.k); + ASSERT_EQ(imp_mr, variant.acc_pack.m); + ASSERT_EQ(imp_nr, variant.acc_pack.n); + ASSERT_EQ(imp_kr, variant.acc_pack.k); ASSERT_EQ(imp_sr, 1); + // Check that stepping is a multiple of accumulation const auto imp_m_step = variant.matmul.get_m_step(); const auto imp_n_step = variant.matmul.get_n_step(); + ASSERT_EQ(imp_m_step, variant.acc_step.m); + ASSERT_EQ(imp_n_step, variant.acc_step.n); - ASSERT_EQ(imp_m_step, variant.acc.m); - ASSERT_EQ(imp_n_step, variant.acc.n); - - // Test kernels - const auto output_area = output_portion.compute_portion(shape.m, shape.n, variant.acc.m, variant.acc.n); + // Test kernels. 
Note that packing and actual stepping might not be the same + const auto pack_portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_pack.m, variant.acc_pack.n); + const auto matmul_portion = + output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); if (variant.lhs_pack.has_value()) { - test_lhs_pack(shape, variant, output_area, reference); + test_lhs_pack(shape, variant, pack_portion, reference); } - test_rhs_pack(shape, variant, output_area, reference); - test_matmul(shape, variant, output_area, reference); + test_rhs_pack(shape, variant, pack_portion, reference); + test_matmul(shape, variant, matmul_portion, reference); } INSTANTIATE_TEST_SUITE_P( @@ -451,21 +537,23 @@ INSTANTIATE_TEST_SUITE_P( testing::Combine( testing::ValuesIn(gemm_variants), testing::ValuesIn({ - MatMulShape{1, 1, 1}, // - MatMulShape{ - 2 * get_sme_vector_length(), 2 * get_sme_vector_length(), - sizeof(int32_t) / sizeof(int8_t)}, // - MatMulShape{20, 30, 40}, // - MatMulShape{1, 49, 21}, // - MatMulShape{23, 1, 43}, // - MatMulShape{32, 14, 1}, // - MatMulShape{123, 85, 45}, // - MatMulShape{130, 130, 6}, + // clang-format off + MatMulShape{ 1, 1, 1}, + MatMulShape{ 1, 49, 21}, + MatMulShape{ 20, 30, 40}, + MatMulShape{ 23, 1, 43}, + MatMulShape{ 32, 14, 1}, + MatMulShape{ 64, 64, 4}, + MatMulShape{123, 85, 45}, + MatMulShape{130, 130, 6}, + // clang-format on }), testing::ValuesIn({ - MatrixPortion(0, 0, 1, 1), // Full matrix. - MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. - MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner. + // clang-format off + MatrixPortion( 0, 0, 1, 1), // Full matrix. + MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner. + MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. 
+ // clang-format on })), [](const auto& info) -> std::string { return test_description( @@ -474,4 +562,37 @@ INSTANTIATE_TEST_SUITE_P( std::get(info.param)); }); +INSTANTIATE_TEST_SUITE_P( + matmul_clamp_qai8_qai8_qsi8cxp, ThisTest, + testing::Combine( + testing::ValuesIn(gemv_variants), + testing::ValuesIn({ + // clang-format off + MatMulShape{1, 1, 1}, + MatMulShape{1, 16, 4}, + MatMulShape{1, 16, 16}, + MatMulShape{1, 17, 4}, + MatMulShape{1, 32, 32}, + MatMulShape{1, 33, 200}, + MatMulShape{1, 64, 4}, + MatMulShape{1, 65, 4}, + MatMulShape{1, 300, 10}, + MatMulShape{1, 512, 4}, + MatMulShape{1, 1523, 10}, + // clang-format on + }), + testing::ValuesIn({ + // clang-format off + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, .5, 1, .5), // Right half + MatrixPortion(0, 0, 1, .5), // Left half + MatrixPortion(0, .25, 1, .5) // Middle half + // clang-format on + })), + [](const auto& info) -> std::string { + return test_description( + std::get(info.param), // + std::get(info.param), // + std::get(info.param)); + }); } // namespace kai::test