From fab6fce4abd7bcab622d78dd3126c87da5ad7d2b Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 27 Jan 2025 14:41:53 +0100 Subject: [PATCH 1/4] Add SME2 kernel for QAI8 GEMV Add support for GEMV like kernel for producing QAI8 from QAI8 LHS and QSI8CXP packed RHS. Update unit tests to include support for new kernel Signed-off-by: Emil Ohlsson --- CHANGELOG.md | 2 + CMakeLists.txt | 9 +- kai/ukernels/matmul/BUILD.bazel | 1 + ...qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c | 961 ++++++++++++++++++ ...qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h | 116 +++ ...matmul_clamp_qai8_qai8_qsi8cxp_interface.h | 50 + .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 212 +++- 7 files changed, 1302 insertions(+), 49 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c create mode 100644 kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h create mode 100644 kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 6bfa81b6..fb3e8f44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Optimizations for FEAT_DotProd. - New 1x8 block size variant of matrix multiplication of QAI8DXP LHS and QSI4C32P RHS with F32 output. - Optimizations for FEAT_DotProd. +- New SME2 micro-kernels: + - Matrix multiplication (1xN) of QAI8 LHS and QSI8 RHS to produce QAI8 output. - Added demonstration of integration using CMake in F16 Arm® Neon™ matrix multiplication example. - Fixes: - Fix the RHS packing micro-kernel kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon to handle null bias. diff --git a/CMakeLists.txt b/CMakeLists.txt index 19a7c863..b9009b0c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -188,17 +188,18 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2 + kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 54b2b785..4fa64cc9 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -144,6 +144,7 @@ SME2_KERNELS = [ "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", + "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c new file mode 100644 index 00000000..6fcb8108 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c @@ -0,0 +1,961 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. + +#include "kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 1; +static const size_t kai_nr = 2; +static const size_t kai_n_step = 16; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; + +size_t kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) { + return kai_n_step * kai_get_sme_vector_length_u8() / kai_kr; +} + +size_t kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) { + return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; +} + +size_t kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) { + return kai_sr; +} + +size_t kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx == 0); + + return m_idx * k; +} + +static size_t kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t k) { + return kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() * + (kai_roundup(k, kai_kr) * sizeof(int8_t) + sizeof(int32_t) + sizeof(int32_t)); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() == 0); + + const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(); + return block_idx * kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(k); +} + +size_t kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() == 0); + + return (m_idx * dst_stride) + (n_idx * sizeof(int8_t)); +} + +size_t kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m, size_t n) { + return m * n * sizeof(int8_t); +} + +void kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params) { + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(lhs_stride); + KAI_ASSUME(m == 1); + + typedef struct { + int32_t b_offset; + int32_t c_offset; + int32_t maxval; + int32_t minval; + } KernelArgs; + + KernelArgs k_args; + k_args.maxval = params->max_value; + k_args.minval = params->min_value; + k_args.c_offset = params->output_zero_point; + + size_t N = n; + size_t K = k; + + const void* A_ptr = lhs; + const void* B_ptr = rhs_packed; + void* output_ptr = dst; + + uint64_t flags = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "mov x8, #0x0\n" + "mov x16, %x[B_ptr]\n" + "cntw x15, ALL, MUL #4\n" + "mov x14, %x[output_ptr]\n" + "add x13, %x[N], x15\n" + "ptrue p2.b\n" + "sub x13, x13, #0x1\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "udiv x13, x13, x15\n" + "mov x22, #0x1\n" + "add x21, x13, #0x3\n" + "and x21, x21, #0xfffffffffffffffc\n" + "mul x21, x21, x15\n" + "mul x21, x21, %x[K]\n" + "1:" // RHS size check loop + "cmp x21, #0x200000\n" + "blt 2f\n" + "tbnz x21, #0, 3f\n" + "lsr x21, x21, #0x1\n" + "lsl x22, x22, #0x1\n" + "b 1b\n" + "2:" // RHS do prefetch + "lsl x20, x21, #0x26\n" + "sub x22, x22, #0x1\n" + "lsl x22, x22, #0x16\n" + "orr x21, x21, x20\n" + "orr x21, x21, x22\n" + ".inst 0xf8b54a1a // rprfm pldonce, x21, [x16]\n" + "3:" // RHS prefetch exit + "add x12, %x[K], #0x3\n" + "cntw x20, ALL, MUL #2\n" + "mov z25.s, #0x0\n" + "mov z27.b, #0x1\n" + "bic x12, x12, #0x3\n" + "bic %x[flags], %x[flags], #0x80000000\n" + "add x12, x12, #0x8\n" + "mul x12, x12, x20\n" + "4:" // Column loop + "cmp x13, #0x4\n" + "bge 25f\n" + "cmp x13, #0x2\n" + "bgt 18f\n" + "beq 11f\n" + "cntw x20, ALL, MUL #2\n" + "add x23, x16, x12\n" + ".inst 0xa0404210 // ld1w { z16.s-z17.s }, pn8.b/Z, [x16]\n" + "cmp %x[N], x20\n" + "mov x11, %x[K]\n" + "csel x23, x23, x16, GT\n" + "mov x21, %x[N]\n" + ".inst 0xa04042f2 // ld1w { z18.s-z19.s }, pn8.b/Z, [x23]\n" + "mov x10, %x[A_ptr]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + "addvl x16, x16, #2\n" + "addvl x23, x23, #2\n" + ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" + "ble 7f\n" + "5:" // Width 1: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d93a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[0]\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d96a0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1]\n" + ".inst 0xc15d9a20 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[2]\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 6f\n" + "sdot z25.s, z13.b, z27.b\n" + "6:" // Width 1: Multiply loop: unique 1: skip row sum + "sub x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 5b\n" + "7:" // Width 1: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002e7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d90a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0]\n" + "ble 8f\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d97a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1]\n" + "ble 8f\n" + ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9920 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2]\n" + "ble 8f\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + "8:" // Width 1: Multiply loop: multiply skip + "tbnz %x[flags], #31, 9f\n" + "9:" // Width 1: Multiply loop: unique 2: skip row sum + ".inst 0xc0060c08 // mova { z8.d-z11.d }, za.d[x8, #0]\n" + ".inst 0xa040421e // ld1w { z30.s-z31.s }, pn8.b/Z, [x16]\n" + "add x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa04042f8 // ld1w { z24.s-z25.s }, pn8.b/Z, [x23]\n" + "add x20, %x[k_args], %[maxval]\n" + "ld1rw { z2.s }, p2/Z, [x22]\n" + "ld1rw { z13.s }, p2/Z, [x21]\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + "ld1rw { z20.s }, p2/Z, [x20]\n" + "fmul z8.s, z8.s, z30.s\n" + "fmul z9.s, z9.s, z31.s\n" + "fmul z10.s, z10.s, z24.s\n" + "fmul z11.s, z11.s, z25.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1a2ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z2.s\n" + ".inst 0xc1b4cda8 // sclamp { z8.s-z11.s }, z13.s, z20.s\n" + "uzp1 z8.h, z8.h, z9.h\n" + "uzp1 z0.h, z10.h, z11.h\n" + "uzp1 z8.b, z8.b, z0.b\n" + "st1b { z8.b }, p1, [x14]\n" + "b 32f\n" + "11:" // Width 2 + "add x24, x16, x12, LSL #1\n" + "cntw x20, ALL, MUL #6\n" + ".inst 0xa0404214 // ld1w { z20.s-z21.s }, pn8.b/Z, [x16]\n" + "add x22, x24, x12\n" + "cmp %x[N], x20\n" + ".inst 0xa040430c // ld1w { z12.s-z13.s }, pn8.b/Z, [x24]\n" + "add x23, x16, x12\n" + "csel x22, x22, x16, GT\n" + ".inst 0xa04042f6 // ld1w { z22.s-z23.s }, pn8.b/Z, [x23]\n" + "mov x11, %x[K]\n" + "sub x21, %x[N], x15\n" + ".inst 0xa04042ce // ld1w { z14.s-z15.s }, pn8.b/Z, [x22]\n" + "mov x10, %x[A_ptr]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xc0040e80 // mova za.d[x8, #0], { z20.d-z23.d }\n" + "addvl x23, x23, #2\n" + "addvl x24, x24, #2\n" + ".inst 0xc0040d81 // mova za.d[x8, #1], { z12.d-z15.d }\n" + "addvl x22, x22, #2\n" + "ble 14f\n" + "12:" // Width 2: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400305 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" + ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d90a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[0]\n" + ".inst 0xa0400311 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9520 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[1]\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9621 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1]\n" + ".inst 0xa0400305 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9820 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2]\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d98a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xa0400309 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + ".inst 0xc15d9d21 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 13f\n" + "sdot z25.s, z13.b, z27.b\n" + "13:" // Width 2: Multiply loop: unique 3: skip row sum + "sub x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 12b\n" + "14:" // Width 2: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400301 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002c3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" + ".inst 0xc15d9021 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[0]\n" + "ble 15f\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa040031d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d96a0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[1]\n" + ".inst 0xc15d97a1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[1]\n" + "ble 15f\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400319 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa04002db // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xc15d9820 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[2]\n" + ".inst 0xc15d9b21 // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[2]\n" + "ble 15f\n" + ".inst 0xa0400219 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002fb // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa040031d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x24]\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + ".inst 0xc15d9f20 // sdot za.s[x8, 0], { z24.b-z27.b }, z13.b[3]\n" + ".inst 0xc15d9fa1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3]\n" + "15:" // Width 2: Multiply loop: multiply skip + "tbnz %x[flags], #31, 16f\n" + "16:" // Width 2: Multiply loop: unique 4: skip row sum + ".inst 0xc0060c00 // mova { z0.d-z3.d }, za.d[x8, #0]\n" + ".inst 0xa0404208 // ld1w { z8.s-z9.s }, pn8.b/Z, [x16]\n" + "add x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa04042fe // ld1w { z30.s-z31.s }, pn8.b/Z, [x23]\n" + "add x20, %x[k_args], %[maxval]\n" + ".inst 0xc0060c24 // mova { z4.d-z7.d }, za.d[x8, #1]\n" + "add x16, x16, x12, LSL #1\n" + "ld1rw { z14.s }, p2/Z, [x22]\n" + "add x23, x23, x12, LSL #1\n" + "ld1rw { z11.s }, p2/Z, [x21]\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + "ld1rw { z10.s }, p2/Z, [x20]\n" + "fmul z0.s, z0.s, z8.s\n" + "fmul z1.s, z1.s, z9.s\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + "fmul z2.s, z2.s, z30.s\n" + "fmul z3.s, z3.s, z31.s\n" + ".inst 0xc1b8e000 // frintn { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc131e000 // fcvtzs { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc1aeab00 // add { z0.s-z3.s }, { z0.s-z3.s }, z14.s\n" + ".inst 0xc1aacd60 // sclamp { z0.s-z3.s }, z11.s, z10.s\n" + "uzp1 z0.h, z0.h, z1.h\n" + "uzp1 z16.h, z2.h, z3.h\n" + "uzp1 z0.b, z0.b, z16.b\n" + "st1b { z0.b }, p2, [x14]\n" + ".inst 0xa1404217 // ld1w { z23.s, z31.s }, pn8.b/Z, [x16]\n" + ".inst 0xa14042f6 // ld1w { z22.s, z30.s }, pn8.b/Z, [x23]\n" + "fmul z4.s, z4.s, z23.s\n" + "fmul z5.s, z5.s, z31.s\n" + "fmul z6.s, z6.s, z22.s\n" + "fmul z7.s, z7.s, z30.s\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1aeab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z14.s\n" + ".inst 0xc1aacd64 // sclamp { z4.s-z7.s }, z11.s, z10.s\n" + "uzp1 z4.h, z4.h, z5.h\n" + "uzp1 z2.h, z6.h, z7.h\n" + "uzp1 z4.b, z4.b, z2.b\n" + "st1b { z4.b }, p1, [x14, #1, MUL VL]\n" + "b 32f\n" + "18:" // Width 3 + "add x26, x16, x12, LSL #2\n" + "cntw x20, ALL, MUL #10\n" + ".inst 0xa0404210 // ld1w { z16.s-z17.s }, pn8.b/Z, [x16]\n" + "add x25, x16, x12, LSL #1\n" + "add x24, x26, x12\n" + ".inst 0xa040435c // ld1w { z28.s-z29.s }, pn8.b/Z, [x26]\n" + "cmp %x[N], x20\n" + "add x23, x16, x12\n" + ".inst 0xa040432c // ld1w { z12.s-z13.s }, pn8.b/Z, [x25]\n" + "add x22, x25, x12\n" + "csel x24, x24, x16, GT\n" + ".inst 0xa04042f2 // ld1w { z18.s-z19.s }, pn8.b/Z, [x23]\n" + "mov x20, #0x2\n" + ".inst 0xa04042ce // ld1w { z14.s-z15.s }, pn8.b/Z, [x22]\n" + "mov x11, %x[K]\n" + ".inst 0xa040431e // ld1w { z30.s-z31.s }, pn8.b/Z, [x24]\n" + "msub x21, x15, x20, %x[N]\n" + "mov x10, %x[A_ptr]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + ".inst 0xc0040e00 // mova za.d[x8, #0], { z16.d-z19.d }\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + ".inst 0xc0040d81 // mova za.d[x8, #1], { z12.d-z15.d }\n" + "addvl x16, x16, #2\n" + "addvl x23, x23, #2\n" + ".inst 0xc0040f82 // mova za.d[x8, #2], { z28.d-z31.d }\n" + "addvl x25, x25, #2\n" + "addvl x22, x22, #2\n" + "addvl x26, x26, #2\n" + "addvl x24, x24, #2\n" + "ble 21f\n" + "19:" // Width 3: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa04002e3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400329 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400351 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9020 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9121 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0]\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d9222 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[0]\n" + ".inst 0xa0400321 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002c3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d97a0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[1]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9421 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[1]\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d94a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1]\n" + ".inst 0xa040033d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9ba1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2]\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xa0400335 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002d7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400351 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9ea1 // sdot za.s[x8, 1], { z20.b-z23.b }, z13.b[3]\n" + ".inst 0xc15d9e22 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 20f\n" + "sdot z25.s, z13.b, z27.b\n" + "20:" // Width 3: Multiply loop: unique 5: skip row sum + "sub x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 19b\n" + "21:" // Width 3: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400325 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa040035d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9220 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[0]\n" + "addvl x26, x26, #2\n" + ".inst 0xa040031f // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d90a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[0]\n" + ".inst 0xc15d93a2 // sdot za.s[x8, 2], { z28.b-z31.b }, z13.b[0]\n" + "ble 22f\n" + ".inst 0xa0400211 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002f3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400329 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002cb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400345 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9620 // sdot za.s[x8, 0], { z16.b-z19.b }, z13.b[1]\n" + "addvl x26, x26, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9521 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[1]\n" + ".inst 0xc15d94a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[1]\n" + "ble 22f\n" + ".inst 0xa0400209 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002eb // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa040033d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400359 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9920 // sdot za.s[x8, 0], { z8.b-z11.b }, z13.b[2]\n" + "addvl x26, x26, #2\n" + ".inst 0xa040031b // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xc15d9ba1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[2]\n" + ".inst 0xc15d9b22 // sdot za.s[x8, 2], { z24.b-z27.b }, z13.b[2]\n" + "ble 22f\n" + ".inst 0xa040021d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa04002ff // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa0400339 // ldnt1b { z24.b-z25.b }, pn8.b/Z, [x25]\n" + ".inst 0xa04002db // ldnt1b { z26.b-z27.b }, pn8.b/Z, [x22]\n" + ".inst 0xa0400355 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9fa0 // sdot za.s[x8, 0], { z28.b-z31.b }, z13.b[3]\n" + ".inst 0xa0400317 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x24]\n" + ".inst 0xc15d9f21 // sdot za.s[x8, 1], { z24.b-z27.b }, z13.b[3]\n" + ".inst 0xc15d9ea2 // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3]\n" + "22:" // Width 3: Multiply loop: multiply skip + "tbnz %x[flags], #31, 23f\n" + "23:" // Width 3: Multiply loop: unique 6: skip row sum + ".inst 0xc0060c18 // mova { z24.d-z27.d }, za.d[x8, #0]\n" + ".inst 0xa0404202 // ld1w { z2.s-z3.s }, pn8.b/Z, [x16]\n" + "add x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa04042e6 // ld1w { z6.s-z7.s }, pn8.b/Z, [x23]\n" + "add x20, %x[k_args], %[maxval]\n" + ".inst 0xc0060c28 // mova { z8.d-z11.d }, za.d[x8, #1]\n" + "add x16, x16, x12, LSL #1\n" + "ld1rw { z0.s }, p2/Z, [x22]\n" + "add x23, x23, x12, LSL #1\n" + ".inst 0xc0060c5c // mova { z28.d-z31.d }, za.d[x8, #2]\n" + "ld1rw { z19.s }, p2/Z, [x21]\n" + ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n" + "ld1rw { z18.s }, p2/Z, [x20]\n" + "fmul z24.s, z24.s, z2.s\n" + "fmul z25.s, z25.s, z3.s\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + "fmul z26.s, z26.s, z6.s\n" + "fmul z27.s, z27.s, z7.s\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1b8e318 // frintn { z24.s-z27.s }, { z24.s-z27.s }\n" + ".inst 0xc131e318 // fcvtzs { z24.s-z27.s }, { z24.s-z27.s }\n" + ".inst 0xc1a0ab18 // add { z24.s-z27.s }, { z24.s-z27.s }, z0.s\n" + ".inst 0xc1b2ce78 // sclamp { z24.s-z27.s }, z19.s, z18.s\n" + "uzp1 z24.h, z24.h, z25.h\n" + "uzp1 z16.h, z26.h, z27.h\n" + "uzp1 z24.b, z24.b, z16.b\n" + "st1b { z24.b }, p2, [x14]\n" + ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" + "add x16, x16, x12, LSL #1\n" + ".inst 0xa14042f1 // ld1w { z17.s, z25.s }, pn8.b/Z, [x23]\n" + "add x23, x23, x12, LSL #1\n" + "fmul z8.s, z8.s, z7.s\n" + "fmul z9.s, z9.s, z15.s\n" + "fmul z10.s, z10.s, z17.s\n" + "fmul z11.s, z11.s, z25.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" + ".inst 0xc1b2ce68 // sclamp { z8.s-z11.s }, z19.s, z18.s\n" + "uzp1 z8.h, z8.h, z9.h\n" + "uzp1 z16.h, z10.h, z11.h\n" + "uzp1 z8.b, z8.b, z16.b\n" + "st1b { z8.b }, p2, [x14, #1, MUL VL]\n" + ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" + ".inst 0xa14042f1 // ld1w { z17.s, z25.s }, pn8.b/Z, [x23]\n" + "fmul z28.s, z28.s, z7.s\n" + "fmul z29.s, z29.s, z15.s\n" + "fmul z30.s, z30.s, z17.s\n" + "fmul z31.s, z31.s, z25.s\n" + ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" + ".inst 0xc1b2ce7c // sclamp { z28.s-z31.s }, z19.s, z18.s\n" + "uzp1 z28.h, z28.h, z29.h\n" + "uzp1 z16.h, z30.h, z31.h\n" + "uzp1 z28.b, z28.b, z16.b\n" + "st1b { z28.b }, p1, [x14, #2, MUL VL]\n" + "b 32f\n" + "25:" // Width 4 + "add x9, x16, x12, LSL #2\n" + "cntw x20, ALL, MUL #14\n" + ".inst 0xa040420c // ld1w { z12.s-z13.s }, pn8.b/Z, [x16]\n" + "add x28, x9, x12, LSL #1\n" + "add x27, x16, x12, LSL #1\n" + ".inst 0xa0404124 // ld1w { z4.s-z5.s }, pn8.b/Z, [x9]\n" + "add x26, x28, x12\n" + "cmp %x[N], x20\n" + ".inst 0xa0404368 // ld1w { z8.s-z9.s }, pn8.b/Z, [x27]\n" + "add x25, x16, x12\n" + "add x24, x27, x12\n" + ".inst 0xa0404380 // ld1w { z0.s-z1.s }, pn8.b/Z, [x28]\n" + "add x22, x9, x12\n" + "csel x26, x26, x16, GT\n" + ".inst 0xa040432e // ld1w { z14.s-z15.s }, pn8.b/Z, [x25]\n" + "mov x20, #0x3\n" + ".inst 0xa040430a // ld1w { z10.s-z11.s }, pn8.b/Z, [x24]\n" + "mov x11, %x[K]\n" + ".inst 0xa04042c6 // ld1w { z6.s-z7.s }, pn8.b/Z, [x22]\n" + "msub x21, x15, x20, %x[N]\n" + "mov x10, %x[A_ptr]\n" + ".inst 0xa0404342 // ld1w { z2.s-z3.s }, pn8.b/Z, [x26]\n" + "mov x20, %x[K]\n" + "whilelt p1.b, XZR, x21\n" + ".inst 0xc0040d80 // mova za.d[x8, #0], { z12.d-z15.d }\n" + "cmp x11, #0x10\n" + ".inst 0xf8b44958 // rprfm pldmany, x20, [x10]\n" + ".inst 0xc0040d01 // mova za.d[x8, #1], { z8.d-z11.d }\n" + "add x23, x16, x12, LSL #3\n" + "addvl x16, x16, #2\n" + ".inst 0xc0040c82 // mova za.d[x8, #2], { z4.d-z7.d }\n" + "addvl x25, x25, #2\n" + "addvl x27, x27, #2\n" + ".inst 0xc0040c03 // mova za.d[x8, #3], { z0.d-z3.d }\n" + "addvl x24, x24, #2\n" + "addvl x9, x9, #2\n" + "addvl x22, x22, #2\n" + "addvl x28, x28, #2\n" + "addvl x26, x26, #2\n" + "ble 28f\n" + "26:" // Width 4: Multiply loop: Main loop head + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "add x10, x10, #0x10\n" + ".inst 0xa0400323 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400369 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa040030b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9020 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[0]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400395 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9121 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[0]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400357 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d90a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[0]\n" + ".inst 0xa0400201 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400323 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xc15d92a3 // sdot za.s[x8, 3], { z20.b-z23.b }, z13.b[0]\n" + ".inst 0xa0400365 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400307 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400131 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9420 // sdot za.s[x8, 0], { z0.b-z3.b }, z13.b[1]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400381 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d94a1 // sdot za.s[x8, 1], { z4.b-z7.b }, z13.b[1]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400343 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d9622 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[1]\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400337 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xc15d9423 // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[1]\n" + ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400381 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9a21 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[2]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400343 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xc15d9823 // sdot za.s[x8, 3], { z0.b-z3.b }, z13.b[2]\n" + ".inst 0xa040037d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa040031f // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400135 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9ca0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002d7 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9fa1 // sdot za.s[x8, 1], { z28.b-z31.b }, z13.b[3]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d9ea2 // sdot za.s[x8, 2], { z20.b-z23.b }, z13.b[3]\n" + ".inst 0xc15d9ca3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3]\n" + "tbnz %x[flags], #31, 27f\n" + "sdot z25.s, z13.b, z27.b\n" + "27:" // Width 4: Multiply loop: unique 7: skip row sum + "sub x11, x11, #0x10\n" + "cmp x11, #0x10\n" + "bgt 26b\n" + "28:" // Width 4: Multiply loop: Single iteration only + "whilelt p0.b, XZR, x11\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "ld1rqb { z13.b }, p0/Z, [x10]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa040013d // ldnt1b { z28.b-z29.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d90a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[0]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002df // ldnt1b { z30.b-z31.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400389 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9221 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[0]\n" + "addvl x28, x28, #2\n" + ".inst 0xa040034b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d93a2 // sdot za.s[x8, 2], { z28.b-z31.b }, z13.b[0]\n" + ".inst 0xc15d9123 // sdot za.s[x8, 3], { z8.b-z11.b }, z13.b[0]\n" + "ble 29f\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400371 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa0400313 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400121 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d94a0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[1]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c3 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9621 // sdot za.s[x8, 1], { z16.b-z19.b }, z13.b[1]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d9422 // sdot za.s[x8, 2], { z0.b-z3.b }, z13.b[1]\n" + ".inst 0xc15d94a3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[1]\n" + "ble 29f\n" + ".inst 0xa0400215 // ldnt1b { z20.b-z21.b }, pn8.b/Z, [x16]\n" + "subs x11, x11, #0x4\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400337 // ldnt1b { z22.b-z23.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400369 // ldnt1b { z8.b-z9.b }, pn8.b/Z, [x27]\n" + "addvl x27, x27, #2\n" + ".inst 0xa040030b // ldnt1b { z10.b-z11.b }, pn8.b/Z, [x24]\n" + "addvl x24, x24, #2\n" + ".inst 0xa0400125 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9aa0 // sdot za.s[x8, 0], { z20.b-z23.b }, z13.b[2]\n" + "addvl x9, x9, #2\n" + ".inst 0xa04002c7 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x22]\n" + "addvl x22, x22, #2\n" + ".inst 0xa0400391 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9921 // sdot za.s[x8, 1], { z8.b-z11.b }, z13.b[2]\n" + "addvl x28, x28, #2\n" + ".inst 0xa0400353 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x26]\n" + "addvl x26, x26, #2\n" + ".inst 0xc15d98a2 // sdot za.s[x8, 2], { z4.b-z7.b }, z13.b[2]\n" + ".inst 0xc15d9a23 // sdot za.s[x8, 3], { z16.b-z19.b }, z13.b[2]\n" + "ble 29f\n" + ".inst 0xa0400205 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x16]\n" + "addvl x16, x16, #2\n" + ".inst 0xa0400327 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x25]\n" + "addvl x25, x25, #2\n" + ".inst 0xa0400361 // ldnt1b { z0.b-z1.b }, pn8.b/Z, [x27]\n" + ".inst 0xa0400303 // ldnt1b { z2.b-z3.b }, pn8.b/Z, [x24]\n" + ".inst 0xa0400131 // ldnt1b { z16.b-z17.b }, pn8.b/Z, [x9]\n" + ".inst 0xc15d9ca0 // sdot za.s[x8, 0], { z4.b-z7.b }, z13.b[3]\n" + ".inst 0xa04002d3 // ldnt1b { z18.b-z19.b }, pn8.b/Z, [x22]\n" + ".inst 0xa0400385 // ldnt1b { z4.b-z5.b }, pn8.b/Z, [x28]\n" + ".inst 0xc15d9c21 // sdot za.s[x8, 1], { z0.b-z3.b }, z13.b[3]\n" + ".inst 0xa0400347 // ldnt1b { z6.b-z7.b }, pn8.b/Z, [x26]\n" + ".inst 0xc15d9e22 // sdot za.s[x8, 2], { z16.b-z19.b }, z13.b[3]\n" + ".inst 0xc15d9ca3 // sdot za.s[x8, 3], { z4.b-z7.b }, z13.b[3]\n" + "29:" // Width 4: Multiply loop: multiply skip + "tbnz %x[flags], #31, 30f\n" + "sdot z25.s, z13.b, z27.b\n" + "30:" // Width 4: Multiply loop: unique 8: skip row sum + ".inst 0xc0060c04 // mova { z4.d-z7.d }, za.d[x8, #0]\n" + ".inst 0xa0404202 // ld1w { z2.s-z3.s }, pn8.b/Z, [x16]\n" + "add x22, %x[k_args], %[c_offset]\n" + "add x21, %x[k_args], %[minval]\n" + ".inst 0xa040432c // ld1w { z12.s-z13.s }, pn8.b/Z, [x25]\n" + "add x20, %x[k_args], %[maxval]\n" + ".inst 0xc0060c3c // mova { z28.d-z31.d }, za.d[x8, #1]\n" + "add x16, x16, x12, LSL #1\n" + "ld1rw { z0.s }, p2/Z, [x22]\n" + "add x25, x25, x12, LSL #1\n" + ".inst 0xc0060c54 // mova { z20.d-z23.d }, za.d[x8, #2]\n" + "ld1rw { z1.s }, p2/Z, [x21]\n" + ".inst 0xc0060c68 // mova { z8.d-z11.d }, za.d[x8, #3]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + "ld1rw { z17.s }, p2/Z, [x20]\n" + "fmul z4.s, z4.s, z2.s\n" + "fmul z5.s, z5.s, z3.s\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmul z6.s, z6.s, z12.s\n" + "fmul z7.s, z7.s, z13.s\n" + ".inst 0xc132e294 // scvtf { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1b8e084 // frintn { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc131e084 // fcvtzs { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc1a0ab04 // add { z4.s-z7.s }, { z4.s-z7.s }, z0.s\n" + ".inst 0xc1b1cc24 // sclamp { z4.s-z7.s }, z1.s, z17.s\n" + "uzp1 z4.h, z4.h, z5.h\n" + "uzp1 z16.h, z6.h, z7.h\n" + "uzp1 z4.b, z4.b, z16.b\n" + "st1b { z4.b }, p2, [x14]\n" + ".inst 0xa1404212 // ld1w { z18.s, z26.s }, pn8.b/Z, [x16]\n" + "add x16, x16, x12, LSL #1\n" + ".inst 0xa0404324 // ld1w { z4.s-z5.s }, pn8.b/Z, [x25]\n" + "add x25, x25, x12, LSL #1\n" + "fmul z28.s, z28.s, z18.s\n" + "fmul z29.s, z29.s, z26.s\n" + "fmul z30.s, z30.s, z4.s\n" + "fmul z31.s, z31.s, z5.s\n" + ".inst 0xc1b8e39c // frintn { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc131e39c // fcvtzs { z28.s-z31.s }, { z28.s-z31.s }\n" + ".inst 0xc1a0ab1c // add { z28.s-z31.s }, { z28.s-z31.s }, z0.s\n" + ".inst 0xc1b1cc3c // sclamp { z28.s-z31.s }, z1.s, z17.s\n" + "uzp1 z28.h, z28.h, z29.h\n" + "uzp1 z16.h, z30.h, z31.h\n" + "uzp1 z28.b, z28.b, z16.b\n" + "st1b { z28.b }, p2, [x14, #1, MUL VL]\n" + ".inst 0xa1404207 // ld1w { z7.s, z15.s }, pn8.b/Z, [x16]\n" + "add x16, x16, x12, LSL #1\n" + ".inst 0xa1404324 // ld1w { z4.s, z12.s }, pn8.b/Z, [x25]\n" + "add x25, x25, x12, LSL #1\n" + "fmul z20.s, z20.s, z7.s\n" + "fmul z21.s, z21.s, z15.s\n" + "fmul z22.s, z22.s, z4.s\n" + "fmul z23.s, z23.s, z12.s\n" + ".inst 0xc1b8e294 // frintn { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc131e294 // fcvtzs { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc1a0ab14 // add { z20.s-z23.s }, { z20.s-z23.s }, z0.s\n" + ".inst 0xc1b1cc34 // sclamp { z20.s-z23.s }, z1.s, z17.s\n" + "uzp1 z20.h, z20.h, z21.h\n" + "uzp1 z16.h, z22.h, z23.h\n" + "uzp1 z20.b, z20.b, z16.b\n" + "st1b { z20.b }, p2, [x14, #2, MUL VL]\n" + ".inst 0xa1404206 // ld1w { z6.s, z14.s }, pn8.b/Z, [x16]\n" + ".inst 0xa1404327 // ld1w { z7.s, z15.s }, pn8.b/Z, [x25]\n" + "fmul z8.s, z8.s, z6.s\n" + "fmul z9.s, z9.s, z14.s\n" + "fmul z10.s, z10.s, z7.s\n" + "fmul z11.s, z11.s, z15.s\n" + ".inst 0xc1b8e108 // frintn { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc131e108 // fcvtzs { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc1a0ab08 // add { z8.s-z11.s }, { z8.s-z11.s }, z0.s\n" + ".inst 0xc1b1cc28 // sclamp { z8.s-z11.s }, z1.s, z17.s\n" + "uzp1 z8.h, z8.h, z9.h\n" + "uzp1 z16.h, z10.h, z11.h\n" + "uzp1 z8.b, z8.b, z16.b\n" + "st1b { z8.b }, p1, [x14, #3, MUL VL]\n" + "addvl x14, x14, #4\n" + "subs x13, x13, #0x4\n" + "mov x16, x23\n" + "sub %x[N], %x[N], x15, LSL #2\n" + "bgt 4b\n" + "32:" // Exit + ".inst 0xd503467f // SMSTOP\n" + : [N] "+&r"(N), [flags] "+&r"(flags) + : [A_ptr] "r"(A_ptr), [B_ptr] "r"(B_ptr), [K] "r"(K), [c_offset] "I"(offsetof(KernelArgs, c_offset)), + [k_args] "r"(&k_args), [maxval] "I"(offsetof(KernelArgs, maxval)), [minval] "I"(offsetof(KernelArgs, minval)), + [output_ptr] "r"(output_ptr) + : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", + "x27", "x28", "x8", "x9", "z0", "z1", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", + "z2", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", + "z6", "z7", "z8", "z9"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h new file mode 100644 index 00000000..cb2932f9 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h @@ -0,0 +1,116 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// -# kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme + +/// -------------------------------------------------- + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void); + +/// Gets the offset in bytes to the data element in the LHS matrix buffer. +/// +/// @param[in] m_idx Row index. This must be 0. +/// @param[in] k Columns of unpacked LHS. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of n_step +/// @param[in] k Number of rows in the unpacked RHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be 0 +/// @param[in] n_idx Column index. Must be multiple of n_step +/// @param[in] dst_stride Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot. +/// +/// @param[in] m Number of output rows to be computed. This must be 1. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Common dimension of the LHS and RHS operand. +/// @param[in] lhs LHS matrix buffer. +/// @param[in] lhs_stride Row stride in bytes of the LHS matrix. Unused parameter. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. Currently, an unused parameter. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Currently, an unused parameter. +/// @param[in] params Quantization parameters +void kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h new file mode 100644 index 00000000..9b59dd6c --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h @@ -0,0 +1,50 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_qai8_qai8_qsi8cxp + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_lhs_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_qai8_qai8_qsi8cxp_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, + size_t dst_stride_row, size_t dst_stride_col, const struct kai_matmul_requantize32_params* params); + +/// Micro-kernel interface +struct kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel { + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_m_step_func_t get_m_step; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_n_step_func_t get_n_step; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_nr_func_t get_nr; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_kr_func_t get_kr; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_sr_func_t get_sr; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_lhs_offset_func_t get_lhs_offset; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_qai8_qai8_qsi8cxp_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_qai8_qai8_qsi8cxp_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 81ffaf4d..93bae72b 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -14,6 +14,8 @@ #include #include "kai/kai_common.h" +#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" +#include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" @@ -91,7 +93,8 @@ const static RhsPackKernel rhs_pack = { struct MatMulVariant { std::string_view name; ///< Test identification - MatMulShape acc; ///< Accumulator shape + MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) + MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) std::function is_supported; ///< HW support check @@ -103,7 +106,12 @@ struct MatMulVariant { const std::array gemm_variants = { MatMulVariant{ .name = "matmul_qai8_qai8p_qsi8cxp", - .acc{ + .acc_pack{ + .m = 2 * get_sme_vector_length(), + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + .acc_step{ .m = 2 * get_sme_vector_length(), .n = 2 * get_sme_vector_length(), .k = sizeof(int32_t) / sizeof(int8_t), @@ -139,6 +147,45 @@ const std::array gemm_variants = { }, }; +const std::array gemv_variants = { + MatMulVariant{ + .name = "matmul_qai8_qai8_qsi8cxp", + .acc_pack{ + .m = 1, + .n = 2 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + .acc_step{ + .m = 1, + .n = 16 * get_sme_vector_length(), + .k = sizeof(int32_t) / sizeof(int8_t), + }, + + .is_supported = cpu_has_sme2, + + .lhs_pack = std::nullopt, + .rhs_pack = rhs_pack, + .matmul = MatMulKernel{ + .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_mr = []() -> size_t { return 1; }, + .get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_packed_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .matmul = + [](size_t m, size_t n, size_t k, const void* lhs, const void* rhs, void* dst, size_t dst_stride_row, + size_t dst_stride_col, const kai_matmul_requantize32_params* quant_param) { + kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( + m, n, k, lhs, sizeof(int8_t), rhs, dst, dst_stride_row, dst_stride_col, quant_param); + }, + }, + }, +}; + constexpr uint64_t seed = 0; ///< Random seed used for tests constexpr float output_clamp_rate = 0.1F; ///< Clamping range in ration of output @@ -181,6 +228,21 @@ struct TestReference { Buffer packed_rhs; }; +/// Make sure that interface matches +static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface + [[maybe_unused]] = { + .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, + .run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, +}; + /// Generate test reference data static TestReference get_test_reference(const MatMulShape& shape, const MatMulVariant& variant) { // ============================================================ @@ -196,27 +258,39 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa // * LHS: 8-bit asymmetric per-matrix quantization. // * RHS: 8-bit symmetric per-channel quantization. // * Bias: 32-bit symmetric per-channel quantization. + // + // Treat entire LHS as one row vector, to calculate one single pair of auto [lhs_qai8, lhs_qai8_scales, lhs_qai8_zero_points] = quantize_asymmetric_per_block_dynamic( lhs_f32.data(), 1, shape.m * shape.k, shape.m * shape.k); const auto lhs_scale = read_array(lhs_qai8_scales.data(), 0); const auto lhs_zero_point = read_array(lhs_qai8_zero_points.data(), 0); + // Transpose, then quantize symmetrically, then transpose back. This will give one + // quantization value for each column const auto rhs_f32_t = transpose(rhs_f32.data(), shape.k, shape.n); auto [rhs_qsi8_t, rhs_scales] = quantize_symmetric_per_block_dynamic(rhs_f32_t.data(), shape.n, shape.k, shape.k); auto rhs_qsi8 = transpose(rhs_qsi8_t.data(), shape.n, shape.k); + // Multiply all bias values with the LHS scale const auto bias_scales = mul(&lhs_scale, 1, 1, rhs_scales.data(), 1, shape.n); + // Calculate quantized bias values, by treating bias as column, and + // scale using RHS scales. This will scale each bias value indiviually auto bias_qsi32 = quantize_symmetric_per_block(bias_f32.data(), bias_scales.data(), shape.n, 1, 1); // Runs the reference implementation of matmul to produce floating-point result. const auto ref_dst_f32 = matmul_nt_t_quantized( - shape.m, shape.n, shape.k, lhs_qai8.data(), &lhs_scale, &lhs_zero_point, shape.m, shape.k, - rhs_qsi8_t.data(), rhs_scales.data(), nullptr, 1, shape.k, bias_qsi32.data(), bias_scales.data(), nullptr, - 1); + shape.m, shape.n, shape.k, // matmul shape + lhs_qai8.data(), &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point + shape.m, shape.k, // LHS quantization window shape + rhs_qsi8_t.data(), rhs_scales.data(), nullptr, // RHS scaling factors + 1, shape.k, // RHS quantization window shape + bias_qsi32.data(), bias_scales.data(), nullptr, // Bias, scaling and zero points + 1 // Bias quantization window shape + ); // Computes the output quantization information and clamping limits. // @@ -247,17 +321,20 @@ static TestReference get_test_reference(const MatMulShape& shape, const MatMulVa const auto ref_dst_f32_clamped = clamp(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max); auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block( - ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, 1, shape.m * shape.n, shape.m * shape.n); + ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, // values, scales, zero point + 1, shape.m * shape.n, // data shape + shape.m * shape.n // quantization window width + ); // Runs the reference implementation of the packing functions. // // The reference packing functions cannot be executed earlier // because we need the reference floating-point output first to have // the quantization information. - auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, variant.acc.m, variant.acc.k); + auto packed_lhs = reorder_block(lhs_qai8.data(), shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k); auto packed_rhs = matmul_pack_rhs_nxk_static_quantized( rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, - variant.acc.n, variant.acc.k); + variant.acc_pack.n, variant.acc_pack.k); return { .clamp = {.min = dst_qai8_clamp_min, .max = dst_qai8_clamp_max}, @@ -287,20 +364,22 @@ static void test_lhs_pack( KAI_ASSUME(variant.lhs_pack.has_value()); const auto imp_packed_lhs_size = - variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc.m, variant.acc.k, 1); + variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); ASSERT_EQ(imp_packed_lhs_size, reference.packed_lhs.size()); Buffer imp_packed_lhs(imp_packed_lhs_size); const auto imp_lhs_offset = variant.lhs_pack->get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t)); - const auto imp_packed_lhs_offset = - variant.lhs_pack->get_packed_lhs_offset(output_area.start_row(), shape.k, variant.acc.m, variant.acc.k, 1); + const auto imp_packed_lhs_offset = variant.lhs_pack->get_packed_lhs_offset( + output_area.start_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); variant.lhs_pack->pack( - output_area.height(), shape.k, variant.acc.m, variant.acc.k, 1, 0, reference.lhs_qai8.data() + imp_lhs_offset, - shape.k * sizeof(int8_t), imp_packed_lhs.data() + imp_packed_lhs_offset); + output_area.height(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1, 0, + reference.lhs_qai8.data() + imp_lhs_offset, shape.k * sizeof(int8_t), + imp_packed_lhs.data() + imp_packed_lhs_offset); const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m - ? variant.lhs_pack->get_packed_lhs_offset(output_area.end_row(), shape.k, variant.acc.m, variant.acc.k, 1) + ? variant.lhs_pack->get_packed_lhs_offset( + output_area.end_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1) : imp_packed_lhs_size; for (size_t i = 0; i < reference.packed_lhs.size(); ++i) { @@ -330,7 +409,7 @@ static void test_rhs_pack( }; variant.rhs_pack.pack( - 1, output_area.width(), shape.k, variant.acc.n, variant.acc.k, 1, shape.n * sizeof(int8_t), + 1, output_area.width(), shape.k, variant.acc_pack.n, variant.acc_pack.k, 1, shape.n * sizeof(int8_t), reference.rhs_qsi8.data() + imp_rhs_offset, reference.bias_qsi32.data() + imp_bias_offset, reference.rhs_scales.data() + imp_scale_offset, imp_packed_rhs.data() + imp_packed_rhs_offset, 0, &imp_pack_rhs_params); @@ -339,16 +418,22 @@ static void test_rhs_pack( ? variant.rhs_pack.get_packed_rhs_offset(output_area.end_col(), shape.k) : imp_packed_rhs_size; + size_t mismatches = 0; for (size_t i = 0; i < reference.packed_rhs.size(); ++i) { if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) { - ASSERT_EQ(imp_packed_rhs[i], reference.packed_rhs[i]); + if (imp_packed_rhs[i] != reference.packed_rhs[i]) { + mismatches += 1; + } } else { - ASSERT_EQ(imp_packed_rhs[i], 0); + if (imp_packed_rhs[i] != 0) { + mismatches += 1; + } } } + ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing"; } -/// Test MatMul of GEMM like kernel +/// Test MatMul of GEMM/GEMV like kernel static void test_matmul( const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n); @@ -377,6 +462,7 @@ static void test_matmul( reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t), sizeof(int8_t), &imp_main_params); + size_t mismatches = 0; for (size_t y = 0; y < shape.m; ++y) { for (size_t x = 0; x < shape.n; ++x) { const auto i = y * shape.n + x; @@ -388,11 +474,10 @@ static void test_matmul( const auto error = std::abs(imp_value - ref_value); const auto threshold = in_area ? 1 : 0; - if (error > threshold) { - ASSERT_EQ(imp_value, ref_value); - } + mismatches += static_cast(error > threshold); } } + ASSERT_EQ(mismatches, 0) << "There are mismatched between reference result actual result"; } using ThisTest = testing::TestWithParam>; @@ -426,24 +511,26 @@ TEST_P(ThisTest, EndToEnd) { const auto imp_kr = variant.matmul.get_kr(); const auto imp_sr = variant.matmul.get_sr(); - ASSERT_EQ(imp_mr, variant.acc.m); - ASSERT_EQ(imp_nr, variant.acc.n); - ASSERT_EQ(imp_kr, variant.acc.k); + ASSERT_EQ(imp_mr, variant.acc_pack.m); + ASSERT_EQ(imp_nr, variant.acc_pack.n); + ASSERT_EQ(imp_kr, variant.acc_pack.k); ASSERT_EQ(imp_sr, 1); + // Check that stepping is a multiple of accumulation const auto imp_m_step = variant.matmul.get_m_step(); const auto imp_n_step = variant.matmul.get_n_step(); + ASSERT_EQ(imp_m_step, variant.acc_step.m); + ASSERT_EQ(imp_n_step, variant.acc_step.n); - ASSERT_EQ(imp_m_step, variant.acc.m); - ASSERT_EQ(imp_n_step, variant.acc.n); - - // Test kernels - const auto output_area = output_portion.compute_portion(shape.m, shape.n, variant.acc.m, variant.acc.n); + // Test kernels. Note that packing and actual stepping might not be the same + const auto pack_portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_pack.m, variant.acc_pack.n); + const auto matmul_portion = + output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); if (variant.lhs_pack.has_value()) { - test_lhs_pack(shape, variant, output_area, reference); + test_lhs_pack(shape, variant, pack_portion, reference); } - test_rhs_pack(shape, variant, output_area, reference); - test_matmul(shape, variant, output_area, reference); + test_rhs_pack(shape, variant, pack_portion, reference); + test_matmul(shape, variant, matmul_portion, reference); } INSTANTIATE_TEST_SUITE_P( @@ -451,21 +538,23 @@ INSTANTIATE_TEST_SUITE_P( testing::Combine( testing::ValuesIn(gemm_variants), testing::ValuesIn({ - MatMulShape{1, 1, 1}, // - MatMulShape{ - 2 * get_sme_vector_length(), 2 * get_sme_vector_length(), - sizeof(int32_t) / sizeof(int8_t)}, // - MatMulShape{20, 30, 40}, // - MatMulShape{1, 49, 21}, // - MatMulShape{23, 1, 43}, // - MatMulShape{32, 14, 1}, // - MatMulShape{123, 85, 45}, // - MatMulShape{130, 130, 6}, + // clang-format off + MatMulShape{ 1, 1, 1}, + MatMulShape{ 1, 49, 21}, + MatMulShape{ 20, 30, 40}, + MatMulShape{ 23, 1, 43}, + MatMulShape{ 32, 14, 1}, + MatMulShape{ 64, 64, 4}, + MatMulShape{123, 85, 45}, + MatMulShape{130, 130, 6}, + // clang-format on }), testing::ValuesIn({ - MatrixPortion(0, 0, 1, 1), // Full matrix. - MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. - MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner. + // clang-format off + MatrixPortion( 0, 0, 1, 1), // Full matrix. + MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner. + MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. + // clang-format on })), [](const auto& info) -> std::string { return test_description( @@ -474,4 +563,37 @@ INSTANTIATE_TEST_SUITE_P( std::get(info.param)); }); +INSTANTIATE_TEST_SUITE_P( + matmul_clamp_qai8_qai8_qsi8cxp, ThisTest, + testing::Combine( + testing::ValuesIn(gemv_variants), + testing::ValuesIn({ + // clang-format off + MatMulShape{1, 1, 1}, + MatMulShape{1, 16, 4}, + MatMulShape{1, 16, 16}, + MatMulShape{1, 17, 4}, + MatMulShape{1, 32, 32}, + MatMulShape{1, 33, 200}, + MatMulShape{1, 64, 4}, + MatMulShape{1, 65, 4}, + MatMulShape{1, 300, 10}, + MatMulShape{1, 512, 4}, + MatMulShape{1, 1523, 10}, + // clang-format on + }), + testing::ValuesIn({ + // clang-format off + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, .5, 1, .5), // Right half + MatrixPortion(0, 0, 1, .5), // Left half + MatrixPortion(0, .25, 1, .5) // Middle half + // clang-format on + })), + [](const auto& info) -> std::string { + return test_description( + std::get(info.param), // + std::get(info.param), // + std::get(info.param)); + }); } // namespace kai::test -- GitLab From 084f3b95cf798d722b5c4635d38919f760c1463f Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Thu, 13 Feb 2025 14:52:44 +0100 Subject: [PATCH 2/4] Address review comments Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 10 +++++----- ...ul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c | 1 - test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 1 - 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b9009b0c..037bd611 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -188,18 +188,18 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2 - kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c + kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) diff --git a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c index 6fcb8108..288c207f 100644 --- a/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c @@ -83,7 +83,6 @@ void kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( KAI_ASSUME(m == 1); typedef struct { - int32_t b_offset; int32_t c_offset; int32_t maxval; int32_t minval; diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 93bae72b..252dd2fb 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -13,7 +13,6 @@ #include #include -#include "kai/kai_common.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" -- GitLab From 34e70648c64d7d39bba4f8069a4a4e8ea4aa64e9 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Thu, 13 Feb 2025 15:06:00 +0100 Subject: [PATCH 3/4] Re-add block specific test Signed-off-by: Emil Ohlsson --- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 252dd2fb..07edbe3d 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -547,6 +547,11 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{123, 85, 45}, MatMulShape{130, 130, 6}, // clang-format on + MatMulShape{ + kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(), + kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(), + kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(), + }, // Not able to add *_step for this kernel as you cannot repeat same values }), testing::ValuesIn({ // clang-format off @@ -580,6 +585,16 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{1, 512, 4}, MatMulShape{1, 1523, 10}, // clang-format on + MatMulShape{ + 1, + kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), + kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), + }, // Not able to add *_step for this kernel as you cannot repeat same values + MatMulShape{ + 1, + kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), + kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), + }, // }), testing::ValuesIn({ // clang-format off -- GitLab From 7ddb392c6acc80a369e91fef59af76e38f6fdbe6 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Mon, 17 Feb 2025 11:47:59 +0100 Subject: [PATCH 4/4] Removing block tests The block test causes some test issues that are unrelated to functionality. Will revisit in separate change Signed-off-by: Emil Ohlsson --- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 07edbe3d..252dd2fb 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -547,11 +547,6 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{123, 85, 45}, MatMulShape{130, 130, 6}, // clang-format on - MatMulShape{ - kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(), - kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(), - kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(), - }, // Not able to add *_step for this kernel as you cannot repeat same values }), testing::ValuesIn({ // clang-format off @@ -585,16 +580,6 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{1, 512, 4}, MatMulShape{1, 1523, 10}, // clang-format on - MatMulShape{ - 1, - kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), - kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), - }, // Not able to add *_step for this kernel as you cannot repeat same values - MatMulShape{ - 1, - kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), - kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(), - }, // }), testing::ValuesIn({ // clang-format off -- GitLab