From 8751d8087ca2b846531410eed0b2fef92806addd Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 19 May 2025 14:55:53 +0100 Subject: [PATCH 01/14] Matmul Micro-kernels F32 <- QAI8DXP(LHS) x QSI8CXP(RHS) optimized for SME * Micro-kernels (1xN) to compute the matrix multiplication of dynamically quantized asymmetric 8-bit integer with per-channel quantization (QAI8DX) LHS matrix and quantized symmetric 8-bit integer with per-channel quantization (QSI8CX) RHS matrix and the accumulation of the result into a single-precision (F32) output, optimized for SME2 technology. Signed-off-by: Anitha Raj --- CHANGELOG.md | 3 + CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 317 ++++++++++++++++++ ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h | 137 ++++++++ .../matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp | 27 +- 6 files changed, 475 insertions(+), 11 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h diff --git a/CHANGELOG.md b/CHANGELOG.md index e9f740d3..363e5378 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- New SME2 micro-kernels: + - Matrix multiplication (1xN) of QAI8DX LHS and QSI8CX RHS to produce F32 output.
+ +## v1.8.0 - New Advanced SIMD micro-kernels: diff --git a/CMakeLists.txt b/CMakeLists.txt index ea7b9926..a4be3ed4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -257,6 +257,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c ) add_library(kleidiai) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 2da2800c..d126af80 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -169,6 +169,7 @@ SME2_KERNELS = [ "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c new file mode 100644 index 00000000..8f3ddc7a --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -0,0 +1,317 @@ +// +// SPDX-FileCopyrightText: Copyright
2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error "This file must be compiled for AArch64, FEAT_SVE2" +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 1; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; // multiple of vector length +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_qvalue_rhs = 1; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_k_roundedup(size_t k) { + // Round up k to be a multiple of 32.
+ return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + size_t lhs_packed_stride = mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // the number of bytes to hold the zero point value + lhs_packed_stride += mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); + + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + + size_t rhs_packed_stride = nr * (k_internal * kai_num_bytes_qvalue_rhs); + rhs_packed_stride += nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { + return kai_n_step * kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t 
kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + return (n_idx / nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const uint64_t k_internal = kai_k_roundedup(k); + const uint64_t lhs_stride = kai_get_lhs_packed_stride(k); + const uint64_t rhs_stride = kai_get_rhs_packed_stride(k); + const uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + + const uint64_t rhs_row_bytes = nr * k_internal; + const uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * 
lhs_stride); + + /* + * x11: zero = 0 // MUST BE x8-x11 + * x15: n initialized as n + * x19: nr initialized as nr + * x21: lhs_packed initialized as lhs_packed + * x22: n_idx + * x23: k_idx + * x24: RHS block ptr + * x25: RHS end ptr + * x26: rhs_packed + * x27: dst_ptr + * x28: tmp_1 + */ + + __asm__ volatile( + + // Setup + " .inst 0xd503477f // smstart \n" + " mov x11, #0 \n" + " mov x15, %[n] \n" + " mov x19, %[nr] \n" + " mov x21, %[lhs_packed] \n" + " ptrue p0.b \n" + " .inst 0x25207810 // ptrue pn8.b \n" + // predicate to load nr words for the RHS sums and scaling factors (should be exactly all true) + " .inst 0x25b36571 // whilelt pn9.s, x11, x19, vlx4 \n" + " dup z30.s, %w[scalar_min] \n" + " dup z31.s, %w[scalar_max] \n" + + // lhs matrix row loop + "1: \n" + // Reset rhs matrix ptr + " mov x26, %[rhs_packed] \n" + // Reset dst_ptr to dst of next GEMV result + " mov x27, %[dst_ptr] \n" + // Reset n index + " mov x22, #0 \n" + // whilelt pn12.s, x22, %[n], vlx4 + " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" + + // rhs matrix row loop (transposed so theoretical columns) + "2: \n" + + // Reset rhs block ptr to start of row + " mov x24, x26 \n" + " add x25, x26, %[rhs_row_bytes] \n" + " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" + " addvl x28, x24, #4 \n" + " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" + " addvl x28, x28, #4 \n" + " .inst 0x25396795 // whilelt pn13.b, x28, x25, vlx4 \n" + " addvl x28, x28, #4 \n" + " .inst 0x25396796 // whilelt pn14.b, x28, x25, vlx4 \n" + " mov x23, #0 \n" + " whilelt p1.b, x23, %[k_internal] \n" + // Zero for sdot accumulation in inner loop + " .inst 0xc00800ff // zero {za} \n" + + // before k loop + "3: \n" + + // Load lhs + " ld1rqb { z0.b }, p1/z , [x21, x23] \n" + + // Load w + " .inst 0xa0408b10 // ld1b { z16.b - z19.b }, pn10/z, [x24] \n" + " .inst 0xa0418f14 // ld1b {z20.b-z23.b}, pn11/z, [x24,#0x4, mul vl]\n" + + " .inst 0xc150f220 // sdot za.s[w11,0, vgx4], {z16.b-z19.b}, z0.b[0] 
\n" + " .inst 0xc150f6a0 // sdot za.s[w11,0, vgx4], {z20.b-z23.b}, z0.b[1] \n" + + " .inst 0xa0429710 // ld1b { z16.b - z19.b }, pn13/z, [x24,#0x8, mul vl] \n" + " .inst 0xa0439b14 // ld1b { z20.b - z23.b }, pn14/z, [x24,#0xC, mul vl] \n" + " .inst 0xc150fa20 // sdot za.s[w11,0, vgx4], {z16.b-z19.b}, z0.b[2] \n" + " .inst 0xc150fea0 // sdot za.s[w11,0, vgx4], {z20.b-z23.b}, z0.b[3] \n" + + // End K block loop + " addvl x24, x24, #16 \n" + " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" + " addvl x28, x24, #4 \n" + " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" + " addvl x28, x28, #4 \n" + " .inst 0x25396795 // whilelt pn13.b, x28, x25, vlx4 \n" + " addvl x28, x28, #4 \n" + " .inst 0x25396796 //whilelt pn14.b, x28, x25, vlx4 \n" + " add x23, x23, #16 \n" + " whilelt p1.b, x23, %[k_internal] \n" + " b.first 3b \n" + + // Finish of accumulators with scaling factors and zero points + + // Load lhs zero point + " add x28, x21, %[k_internal] \n" + " ld1rw { z2.s }, p0/z , [x28] \n" + // Load lhs scaling factor + " ld1rw { z3.s }, p0/z , [x28, #4] \n" + // Load rhs sums + " add x28, x26, %[rhs_row_bytes] \n" + " .inst 0xa040c794 // ld1w { z20.s - z23.s }, pn9/z, [x28] \n" + // Load rhs scaling factors + " .inst 0xa041c798 // ld1w {z24.s-z27.s}, pn9/z, [x28, #0x4, mul vl]\n" + // Load biases + " .inst 0xa042c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, #0x8, mul vl]\n" + + // Get accumulated value out of ZA + " .inst 0xc0066c04 // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] \n" + + // za contains a * w, which needs to be done + z * wsum -> smla + // zero point * rhs row sum + " mla z4.s, p0/m, z20.s, z2.s \n" + " mla z5.s, p0/m, z21.s, z2.s \n" + " mla z6.s, p0/m, z22.s, z2.s \n" + " mla z7.s, p0/m, z23.s, z2.s \n" + + // Convert to float + " .inst 0xc132e084 // scvtf { z4.s - z7.s }, { z4.s - z7.s } \n" + + // lhs scaling factor * rhs scaling factor + " fmul z24.s, z24.s, z3.s \n" + " fmul z25.s, z25.s, z3.s \n" + " fmul z26.s, z26.s, z3.s \n" + " fmul z27.s, 
z27.s, z3.s \n" + + // Bias + combined scaling factor * combined accumulator + " fmla z12.s, p0/m, z24.s, z4.s \n" + " fmla z13.s, p0/m, z25.s, z5.s \n" + " fmla z14.s, p0/m, z26.s, z6.s \n" + " fmla z15.s, p0/m, z27.s, z7.s \n" + + // Clamp + " .inst 0xc1bfcbcc // fclamp { z12.s - z15.s }, z30.s, z31.s \n" + + // Store + " .inst 0xa036d36c // st1w {z12.s-z15.s}, pn12, [x27, x22, lsl #2] \n" + + // End rhs row loop + " add x26, x26, %[rhs_stride] \n" + // nr == svlb + " addvl x22, x22, #1 \n" + // whilelt pn12.s, x22, %[n], vlx4 + " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" + " b.lt 2b \n" + + // End lhs row loop + " add %[dst_ptr], %[dst_ptr], %[dst_stride_row] \n" + " add x21, x21, %[lhs_stride] \n" + " cmp x21, %[lhs_end_ptr] \n" + " b.lt 1b \n" + + " .inst 0xd503467f // smstop \n" + + : [dst_ptr] "+r"(dst) + : [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_packed] "r"(lhs_packed), [rhs_packed] "r"(rhs_packed), + [dst_stride_row] "r"(dst_stride_row), [scalar_min] "r"(scalar_min), [scalar_max] "r"(scalar_max), + [k_internal] "r"(k_internal), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride), [nr] "r"(nr), + [rhs_row_bytes] "r"(rhs_row_bytes), [lhs_end_ptr] "r"(lhs_end_ptr) + : "x11", "x15", "x19", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "p0", "p1", "p8", "p9", "p10", + "p11", "p12", "p13", "p14", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", + "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", + "z28", "z29", "z30", "z31", +#ifdef __ARM_STATE_ZA + "za", +#endif + "memory", "cc"); +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h new file mode 100644 index 00000000..c7d85e8c --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h @@ -0,0 +1,137 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. 
+/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be 1. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. 
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed +/// RHS matrix: Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) and packed. +/// Output tile: (rows x cols) = 1 x 4vl +/// +/// Features used: sme2 +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index 1bad32a6..36fb8f8c 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -15,6 +15,7 @@ #include #include +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" @@ -36,16 +37,19 @@ namespace kai::test { -static const std::array, 4> - variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp = { - {{UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod), - "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod}, - {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod), - "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod}, - {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod), - "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod}, - {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm), - "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm}}}; +static const std::array, 5> + 
variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp = {{ + {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod), + "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod}, + {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod), + "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod}, + {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod), + "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod}, + {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm), + "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm}, + {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot), + "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", cpu_has_sme2}, + }}; class MatMulTest_f32_qai8dxp_qsi8cxp : public ::testing::TestWithParam {}; @@ -339,7 +343,8 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( MatMulShape{17, 33, 67}, // MatMulShape{19, 35, 63}, // - MatMulShape{1, 27, 31}), + MatMulShape{1, 27, 31}, // + MatMulShape{1, 65, 35}, MatMulShape{1, 64, 65}, MatMulShape{1, 63, 15}), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. 
-- GitLab From 9e19d22500eadae49e6e6bba3caca9465666a484 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 20 May 2025 13:14:14 +0100 Subject: [PATCH 02/14] Add large shape to test non-zero n_idx Signed-off-by: Anitha Raj --- test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index 36fb8f8c..86a0833c 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -344,7 +344,10 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{17, 33, 67}, // MatMulShape{19, 35, 63}, // MatMulShape{1, 27, 31}, // - MatMulShape{1, 65, 35}, MatMulShape{1, 64, 65}, MatMulShape{1, 63, 15}), + MatMulShape{1, 65, 35}, // + MatMulShape{1, 64, 65}, // + MatMulShape{1, 63, 15}, // + MatMulShape{1, 130, 15}, ), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. -- GitLab From b5005af63ad3193bbcbc0a500fd8dd691370b39d Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 20 May 2025 13:31:10 +0100 Subject: [PATCH 03/14] Fix typo Signed-off-by: Anitha Raj --- test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index 86a0833c..b564996b 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -347,7 +347,7 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{1, 65, 35}, // MatMulShape{1, 64, 65}, // MatMulShape{1, 63, 15}, // - MatMulShape{1, 130, 15}, ), + MatMulShape{1, 130, 15}), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. 
-- GitLab From 3b5b98383709e328d12aee9288ac83aa08e06216 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 20 May 2025 22:16:33 +0100 Subject: [PATCH 04/14] Address review comments Signed-off-by: Anitha Raj --- CMakeLists.txt | 2 +- ...i_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a4be3ed4..2eb1609c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -250,6 +250,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c @@ -257,7 +258,6 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c ) add_library(kleidiai) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c 
b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index 8f3ddc7a..5dcfb37b 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -149,7 +149,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( const uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * lhs_stride); /* - * x11: zero = 0 // MUST BE x8-x11 + * x11: zero = 0 * x15: n initialized as n * x19: nr initialized as nr * x21: lhs_packed initialized as lhs_packed -- GitLab From 9a6013a83a4350b39e15a57ae1ba650fc7adab1f Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 22 May 2025 12:06:28 +0100 Subject: [PATCH 05/14] Add Matrix multiplication (MxN) of QAI8DX LHS and QSI8CX RHS to produce F32 output micro-kernel * Micro-kernels (MxN) to compute the matrix multiplication of dynamically quantized asymmetric 8-bit integer with per-channel quantization (QAI8DX) LHS matrix and quantized symmetric 8-bit integer with per-channel quantization (QSI8CX) RHS matrix and the accumulation of the result into a single-precision (F32) output, optimized for SME2 technology. 
Signed-off-by: Anitha Raj --- CHANGELOG.md | 1 + CMakeLists.txt | 1 + kai/ukernels/matmul/BUILD.bazel | 1 + ...8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c | 323 ++++++++++++++++++ ...8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h | 129 +++++++ ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 12 +- .../matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp | 11 +- 7 files changed, 470 insertions(+), 8 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 363e5378..0605f0e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New SME2 micro-kernels: - Matrix multiplication (1xN) of QAI8DX LHS and QSI8CX RHS to produce F32 output. + - Matrix multiplication (MxN) of QAI8DX LHS and QSI8CX RHS to produce F32 output. 
## v1.8.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 2eb1609c..9b230ea1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -249,6 +249,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index d126af80..b4df8170 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -169,6 +169,7 @@ SME2_KERNELS = [ "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c new file mode 100644 index 00000000..ccb3aef7 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c @@ -0,0 +1,323 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Do not flag up inline assembly blocks +#pragma GCC diagnostic ignored "-Woverlength-strings" + +#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) +#error "This file must be compiled for AArch64, FEAT_SVE2" +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +// Compute args +static const size_t kai_m_step = 1; // multiple of vector length +static const size_t kai_n_step = 4; // multiple of vector length +// Packing args +static const size_t kai_mr = 1; // multiple of vector length +static const size_t kai_nr = 4; // multiple of vector length +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_qvalue_rhs = 1; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_k_roundedup(size_t k) { + // Round up k to be a multiple of 32. + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + size_t lhs_packed_stride = mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // the number of bytes to hold the zero point value + lhs_packed_stride += mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + size_t rhs_packed_stride = nr * (k_internal * kai_num_bytes_qvalue_rhs); + rhs_packed_stride += nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t 
kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { + return kai_m_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { + return kai_n_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); + + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + + return (m_idx / mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); + + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + + return (n_idx / nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); + KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); + + return ((n_idx * kai_num_bytes_dst_value) + m_idx * 
dst_stride); +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n) { + return (m * n * kai_num_bytes_dst_value); +} + +void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + KAI_ASSERT(n > 0); + KAI_ASSERT(m > 0); + + const float scalar_bounds[2] = {scalar_min, scalar_max}; + + typedef struct { + size_t lhs_stride; + size_t rhs_stride; + size_t mr; + size_t nr; + size_t m_blk; + size_t dst_inc; + size_t rhs_row_bytes; + } KernelArgs; + + KernelArgs ka; + + // Constants + ka.mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + ka.nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + ka.lhs_stride = kai_get_lhs_packed_stride(k); + ka.rhs_stride = kai_get_rhs_packed_stride(k); + const size_t k_internal = kai_k_roundedup(k); + ka.m_blk = k_internal * ka.mr; + ka.dst_inc = ka.mr * dst_stride_row; + ka.rhs_row_bytes = ka.nr * k_internal; + + __asm__ volatile( + " .inst 0xd503477f //smstart \n" + " mov x19, %[dst] \n" + " mov x20, %[lhs] \n" + " cntw x7 \n" + + " ptrue p2.b \n" + " ld1rw {z30.s}, p2/Z, [%[scalar_bounds]] \n" + + " ld1rw {z31.s}, p2/Z, [%[scalar_bounds], #4] \n" + + // M loop head + " mov x12, %[m] \n" + " .inst 0x25ac17e0 //whilelt p0.s, xzr, x12 \n" + "1: \n" + " mov x8, %[rhs] \n" + " mov x9, x19 \n" + " mov x13, %[n] \n" + " cmp x7, x12 \n" + " csel x16, x7, x12, lt \n" + " lsl x16, x16, #2 \n" + + // N loop head + " ldr x24, [%x[args_ptr], %[offset_rhs_row_bytes]] \n" + " add x24, x24, x8 \n" + " mov x11, x8 \n" + " .inst 0x25784570 // whilelt pn8.h, x11, x24, vlx2 \n" + " addvl x11, x8, #2 \n" + " .inst 0x25784572 // whilelt 
pn10.h, x11, x24, vlx2 \n" + + "2: \n" + " mov x10, x20 \n" + " mov x11, x8 \n" + " mov x17, x9 \n" + " .inst 0x25ad67f1 //whilelt pn9.s, xzr, x13, vlx4 \n" + + // K loop + " .inst 0xc00800ff //zero {za} \n" + " ldr x24, [%x[args_ptr], %[offset_m_blk]] \n" + " add x14, x10, x24 \n" + + "3: \n" + " .inst 0xa540a144 // ld1w { z4.s }, p0/z, [x10] \n" + " .inst 0x042a502a // addvl x10, x10, #1 \n" + + " .inst 0xa0402168 //ld1h { z8.h - z9.h }, pn8/z, [x11] \n" + " .inst 0xa0884880 //smopa za0.s, p2/m, p2/m, z4.b, z8.b \n" + " .inst 0xa0894881 //smopa za1.s, p2/m, p2/m, z4.b, z9.b \n" + + " .inst 0xa041296a //ld1h { z10.h - z11.h }, pn8/z, [x11, #0x2, mul vl] \n" + " .inst 0xa08a4882 //smopa za2.s, p2/m, p2/m, z4.b, z10.b\n" + " .inst 0xa08b4883 //smopa za3.s, p2/m, p2/m, z4.b, z11.b\n" + + " .inst 0x042b508b // addvl x11, x11, #4 \n" + " cmp x10, x14 \n" + " b.lt 3b \n" + + // RHS row sum, scale factor & bias + " .inst 0xa040c560 //ld1w { z0.s-z3.s }, pn9/z, [x11] \n" + " .inst 0xa041c564 //ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] \n" + " .inst 0xa042c568 //ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl]\n" + " .inst 0x042b518b //addvl x11, x11, #12 \n" + " .inst 0xc132e000 //scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + + // Store loop + " mov x14, #0 \n" + " addvl x15, x10, #1 \n" + + "4: \n" + // Load LHS Row-offset & SF + " ld1rw {z16.s}, p2/z, [x10] \n" + " ld1rw {z17.s}, p2/z, [x15] \n" + " add x10, x10, #4 \n" + " add x15, x15, #4 \n" + " scvtf z16.s, p2/m, z16.s \n" + + // offset x Row-sum + " fmul z24.s, z16.s, z0.s \n" + " fmul z25.s, z16.s, z1.s \n" + " fmul z26.s, z16.s, z2.s \n" + " fmul z27.s, z16.s, z3.s \n" + + // Scaling factors + " fmul z20.s, z17.s, z4.s \n" + " fmul z21.s, z17.s, z5.s \n" + " fmul z22.s, z17.s, z6.s \n" + " fmul z23.s, z17.s, z7.s \n" + + // Result = offset x Row-sum x SFs + " fmul z24.s, z24.s, z20.s \n" + " fmul z25.s, z25.s, z21.s \n" + " fmul z26.s, z26.s, z22.s \n" + " fmul z27.s, z27.s, z23.s \n" + + // Load inner accumulation & 
convert + " .inst 0xc006440c //mova { z12.b-z15.b }, za0h.b[w14, 0:3]\n" + " .inst 0xc132e18c //scvtf { z12.s-z15.s }, { z12.s-z15.s } \n" + + // Result += iacc x SF + " fmla z24.s, p2/m, z20.s, z12.s \n" + " fmla z25.s, p2/m, z21.s, z13.s \n" + " fmla z26.s, p2/m, z22.s, z14.s \n" + " fmla z27.s, p2/m, z23.s, z15.s \n" + + // Add the bias + " fadd z24.s, p2/m, z24.s, z8.s \n" + " fadd z25.s, p2/m, z25.s, z9.s \n" + " fadd z26.s, p2/m, z26.s, z10.s \n" + " fadd z27.s, p2/m, z27.s, z11.s \n" + + // CLAMP and store + " .inst 0xc1bfcbd8 //fclamp { z24.s-z27.s }, z30.s, z31.s\n" + " .inst 0xa060c638 //st1w { z24.s-z27.s }, pn9, [x17] \n" + + " add x17, x17, %[dst_stride_row] \n" + " add x14, x14, #4 \n" + " cmp x14, x16 \n" + " b.lt 4b \n" + + // N loop tail + " ldr x24, [%x[args_ptr], %[offset_rhs_stride]] \n" + " add x8, x8, x24 \n" + " .inst 0x04295089 // addvl x9, x9, #4 \n" + " ldr x24, [%x[args_ptr], %[offset_rhs_row_bytes]] \n" + " add x24, x24, x8 \n" + " mov x11, x8 \n" + " .inst 0x25784570 // whilelt pn8.h, x11, x24, vlx2\n" + " addvl x11, x8, #2 \n" + " .inst 0x25784572 // whilelt pn10.h, x11, x24, vlx2 \n" + " ldr x24, [%x[args_ptr], %[offset_nr]] \n" + " sub x13, x13, x24 \n" + " cmp xzr, x13 \n" + " b.mi 2b \n" + + // M loop tail + " ldr x24, [%x[args_ptr], %[offset_lhs_stride]] \n" + " add x20, x20, x24 \n" + " ldr x24, [%x[args_ptr], %[offset_dst_inc]] \n" + " add x19, x19, x24 \n" + " ldr x24, [%x[args_ptr], %[offset_mr]] \n" + " sub x12, x12, x24 \n" + " whilelt p0.s, xzr, x12 \n" + " b.mi 1b \n" + + "5: \n" + " .inst 0xd503467f //smstop \n" + : + : [m] "r"(m), [n] "r"(n), [k] "r"(k), [dst] "r"(dst), [lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), + [dst_stride_row] "r"(dst_stride_row), [scalar_bounds] "r"(scalar_bounds), [args_ptr] "r"(&ka), + [offset_m_blk] "I"(offsetof(KernelArgs, m_blk)), [offset_mr] "I"(offsetof(KernelArgs, mr)), + [offset_nr] "I"(offsetof(KernelArgs, nr)), [offset_dst_inc] "I"(offsetof(KernelArgs, dst_inc)), + [offset_lhs_stride] 
"I"(offsetof(KernelArgs, lhs_stride)), + [offset_rhs_stride] "I"(offsetof(KernelArgs, rhs_stride)), + [offset_rhs_row_bytes] "I"(offsetof(KernelArgs, rhs_row_bytes)) + : "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x24", "p0", + "p2", "p8", "p9", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", + "z14", "z15", "z16", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z30", "z31", +#ifdef __ARM_STATE_ZA + "za", +#endif + "cc", "memory"); +} + +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h new file mode 100644 index 00000000..bf04a4f1 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h @@ -0,0 +1,129 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value
size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values.
However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K). 
+/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. +/// @param[in] n_idx Column index in the DST matrix. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the destination(DST) offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed +/// RHS matrix: Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) and packed. +/// Output tile: (rows x cols) = 1vl x 4vl +/// +/// Features used: sme2 +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. 
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index 5dcfb37b..83da9d74 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -20,7 +20,7 @@ // Compute args static const size_t kai_m_step = 1; -static const size_t kai_n_step = 1; +static const size_t kai_n_step = 4; // multiple of vector length // Packing args static const size_t kai_mr = 1; static const size_t kai_nr = 4; // multiple of vector length @@ -80,7 +80,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(v } size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { - return kai_n_step * kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + return kai_n_step * kai_get_sme_vector_length_u32(); } size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { @@ -100,21 +100,21 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) } size_t 
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) { - KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) { - KAI_ASSUME((n_idx % kai_n_step) == 0); + KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); return (n_idx / nr) * kai_get_rhs_packed_stride(k); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSUME((m_idx % kai_m_step) == 0); - KAI_ASSUME((n_idx % kai_n_step) == 0); + KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); + KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; } diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index b564996b..af7181db 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -15,6 +15,7 @@ #include #include +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" #include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" @@ -37,7 +38,7 @@ namespace kai::test { -static const std::array, 5> +static const std::array, 6> variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp = {{ {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod), "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod}, @@ -49,6 +50,8 @@ static const std::array {}; @@ -347,7 +350,11 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{1, 65, 35}, // MatMulShape{1, 64, 65}, // MatMulShape{1, 63, 15}, // - MatMulShape{1, 130, 15}), + MatMulShape{1, 130, 15}, // + MatMulShape{15, 65, 35}, // + MatMulShape{16, 64, 65}, // + MatMulShape{17, 63, 15}, // + MatMulShape{20, 130, 15}), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. -- GitLab From 26707dd6189d7738d2e1f301b5d6f03636e3a7a7 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 27 May 2025 15:38:46 +0100 Subject: [PATCH 06/14] Extract asm to .S file for GEMV kernel - To enable MSVC support move the inline assembly code from kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c to a pure assembly kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S Signed-off-by: Anitha Raj --- CMakeLists.txt | 7 +- kai/ukernels/matmul/BUILD.bazel | 13 +- ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 206 ++++-------------- ...8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S | 157 +++++++++++++ 4 files changed, 214 insertions(+), 169 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b230ea1..aff431f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,7 +239,13 @@ set(KLEIDIAI_FILES_SME kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c ) +set(KLEIDIAI_FILES_SME2_ASM + 
kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S +) + set(KLEIDIAI_FILES_SME2 + ${KLEIDIAI_FILES_SME2_ASM} kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -251,7 +257,6 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index b4df8170..b49cfb91 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -157,6 +157,11 @@ SME_KERNELS = [ "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme", ] +# buildifier: keep sorted +SME2_KERNELS_ASM = [ + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", +] + # buildifier: keep sorted SME2_KERNELS = [ 
"imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", @@ -170,7 +175,6 @@ SME2_KERNELS = [ "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa", - "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", @@ -281,6 +285,13 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS], ) +kai_c_library( + name = "sme2_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in SME2_KERNELS_ASM] + [ukernel + ".c" for ukernel in SME2_KERNELS_ASM], + cpu_uarch = kai_cpu_sme2(), + textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], +) + kai_c_library( name = "matmul", visibility = ["//visibility:public"], diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index 83da9d74..c264810e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -13,11 +13,28 @@ #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" -#include -#include - #include "kai/kai_common.h" +typedef struct { + float* dst; + const void* lhs_packed; + const void* 
rhs_packed; + float clamp_min; + float clamp_max; + size_t dst_stride_row; + size_t m; + size_t n; + size_t k; + size_t k_internal; + size_t lhs_stride; + size_t rhs_stride; + size_t nr; + size_t rhs_row_bytes; + size_t lhs_end; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(KernelArgs* args_ptr); + // Compute args static const size_t kai_m_step = 1; static const size_t kai_n_step = 4; // multiple of vector length @@ -148,170 +165,25 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( const uint64_t rhs_row_bytes = nr * k_internal; const uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * lhs_stride); - /* - * x11: zero = 0 - * x15: n initialized as n - * x19: nr initialized as nr - * x21: lhs_packed initialized as lhs_packed - * x22: n_idx - * x23: k_idx - * x24: RHS block ptr - * x25: RHS end ptr - * x26: rhs_packed - * x27: dst_ptr - * x28: tmp_1 - */ - - __asm__ volatile( - - // Setup - " .inst 0xd503477f // smstart \n" - " mov x11, #0 \n" - " mov x15, %[n] \n" - " mov x19, %[nr] \n" - " mov x21, %[lhs_packed] \n" - " ptrue p0.b \n" - " .inst 0x25207810 // ptrue pn8.b \n" - // predicate to load nr words for the RHS sums and scaling factors (should be exactly all true) - " .inst 0x25b36571 // whilelt pn9.s, x11, x19, vlx4 \n" - " dup z30.s, %w[scalar_min] \n" - " dup z31.s, %w[scalar_max] \n" - - // lhs matrix row loop - "1: \n" - // Reset rhs matrix ptr - " mov x26, %[rhs_packed] \n" - // Reset dst_ptr to dst of next GEMV result - " mov x27, %[dst_ptr] \n" - // Reset n index - " mov x22, #0 \n" - // whilelt pn12.s, x22, %[n], vlx4 - " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" - - // rhs matrix row loop (transposed so theoretical columns) - "2: \n" - - // Reset rhs block ptr to start of row - " mov x24, x26 \n" - " add x25, x26, %[rhs_row_bytes] \n" - " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" - " addvl x28, x24, #4 \n" - " .inst 0x25396793 // whilelt pn11.b, x28, x25, 
vlx4 \n" - " addvl x28, x28, #4 \n" - " .inst 0x25396795 // whilelt pn13.b, x28, x25, vlx4 \n" - " addvl x28, x28, #4 \n" - " .inst 0x25396796 // whilelt pn14.b, x28, x25, vlx4 \n" - " mov x23, #0 \n" - " whilelt p1.b, x23, %[k_internal] \n" - // Zero for sdot accumulation in inner loop - " .inst 0xc00800ff // zero {za} \n" - - // before k loop - "3: \n" - - // Load lhs - " ld1rqb { z0.b }, p1/z , [x21, x23] \n" - - // Load w - " .inst 0xa0408b10 // ld1b { z16.b - z19.b }, pn10/z, [x24] \n" - " .inst 0xa0418f14 // ld1b {z20.b-z23.b}, pn11/z, [x24,#0x4, mul vl]\n" - - " .inst 0xc150f220 // sdot za.s[w11,0, vgx4], {z16.b-z19.b}, z0.b[0] \n" - " .inst 0xc150f6a0 // sdot za.s[w11,0, vgx4], {z20.b-z23.b}, z0.b[1] \n" - - " .inst 0xa0429710 // ld1b { z16.b - z19.b }, pn13/z, [x24,#0x8, mul vl] \n" - " .inst 0xa0439b14 // ld1b { z20.b - z23.b }, pn14/z, [x24,#0xC, mul vl] \n" - " .inst 0xc150fa20 // sdot za.s[w11,0, vgx4], {z16.b-z19.b}, z0.b[2] \n" - " .inst 0xc150fea0 // sdot za.s[w11,0, vgx4], {z20.b-z23.b}, z0.b[3] \n" - - // End K block loop - " addvl x24, x24, #16 \n" - " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" - " addvl x28, x24, #4 \n" - " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" - " addvl x28, x28, #4 \n" - " .inst 0x25396795 // whilelt pn13.b, x28, x25, vlx4 \n" - " addvl x28, x28, #4 \n" - " .inst 0x25396796 //whilelt pn14.b, x28, x25, vlx4 \n" - " add x23, x23, #16 \n" - " whilelt p1.b, x23, %[k_internal] \n" - " b.first 3b \n" - - // Finish of accumulators with scaling factors and zero points - - // Load lhs zero point - " add x28, x21, %[k_internal] \n" - " ld1rw { z2.s }, p0/z , [x28] \n" - // Load lhs scaling factor - " ld1rw { z3.s }, p0/z , [x28, #4] \n" - // Load rhs sums - " add x28, x26, %[rhs_row_bytes] \n" - " .inst 0xa040c794 // ld1w { z20.s - z23.s }, pn9/z, [x28] \n" - // Load rhs scaling factors - " .inst 0xa041c798 // ld1w {z24.s-z27.s}, pn9/z, [x28, #0x4, mul vl]\n" - // Load biases - " .inst 0xa042c78c // ld1w 
{z12.s-z15.s}, pn9/z, [x28, #0x8, mul vl]\n" - - // Get accumulated value out of ZA - " .inst 0xc0066c04 // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] \n" - - // za contains a * w, which needs to be done + z * wsum -> smla - // zero point * rhs row sum - " mla z4.s, p0/m, z20.s, z2.s \n" - " mla z5.s, p0/m, z21.s, z2.s \n" - " mla z6.s, p0/m, z22.s, z2.s \n" - " mla z7.s, p0/m, z23.s, z2.s \n" - - // Convert to float - " .inst 0xc132e084 // scvtf { z4.s - z7.s }, { z4.s - z7.s } \n" - - // lhs scaling factor * rhs scaling factor - " fmul z24.s, z24.s, z3.s \n" - " fmul z25.s, z25.s, z3.s \n" - " fmul z26.s, z26.s, z3.s \n" - " fmul z27.s, z27.s, z3.s \n" - - // Bias + combined scaling factor * combined accumulator - " fmla z12.s, p0/m, z24.s, z4.s \n" - " fmla z13.s, p0/m, z25.s, z5.s \n" - " fmla z14.s, p0/m, z26.s, z6.s \n" - " fmla z15.s, p0/m, z27.s, z7.s \n" - - // Clamp - " .inst 0xc1bfcbcc // fclamp { z12.s - z15.s }, z30.s, z31.s \n" - - // Store - " .inst 0xa036d36c // st1w {z12.s-z15.s}, pn12, [x27, x22, lsl #2] \n" - - // End rhs row loop - " add x26, x26, %[rhs_stride] \n" - // nr == svlb - " addvl x22, x22, #1 \n" - // whilelt pn12.s, x22, %[n], vlx4 - " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" - " b.lt 2b \n" - - // End lhs row loop - " add %[dst_ptr], %[dst_ptr], %[dst_stride_row] \n" - " add x21, x21, %[lhs_stride] \n" - " cmp x21, %[lhs_end_ptr] \n" - " b.lt 1b \n" - - " .inst 0xd503467f // smstop \n" - - : [dst_ptr] "+r"(dst) - : [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_packed] "r"(lhs_packed), [rhs_packed] "r"(rhs_packed), - [dst_stride_row] "r"(dst_stride_row), [scalar_min] "r"(scalar_min), [scalar_max] "r"(scalar_max), - [k_internal] "r"(k_internal), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride), [nr] "r"(nr), - [rhs_row_bytes] "r"(rhs_row_bytes), [lhs_end_ptr] "r"(lhs_end_ptr) - : "x11", "x15", "x19", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "p0", "p1", "p8", "p9", "p10", - "p11", "p12", "p13", "p14", 
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", - "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", - "z28", "z29", "z30", "z31", -#ifdef __ARM_STATE_ZA - "za", -#endif - "memory", "cc"); + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_max = scalar_max; + args.clamp_min = scalar_min; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.k = k; + args.k_internal = k_internal; + args.lhs_stride = lhs_stride; + args.rhs_stride = rhs_stride; + args.nr = nr; + args.rhs_row_bytes = rhs_row_bytes; + args.lhs_end = lhs_end_ptr; + + kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S new file mode 100644 index 00000000..bf98f92f --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S @@ -0,0 +1,157 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + 
#define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) + stp x29, x30, [sp, -176]! + mov x29, sp + stp x19, x20, [sp, 16] + stp x21, x22, [sp, 32] + stp x23, x24, [sp, 48] + stp x25, x26, [sp, 64] + stp x27, x28, [sp, 80] + stp d8, d9, [sp, 96] + stp d10, d11, [sp, 112] + stp d12, d13, [sp, 128] + stp d14, d15, [sp, 144] + KAI_ASM_INST(0xd503477f) // smstart + ldr x16, [x0, #0] + mov x11, #0 + ldr x15, [x0, #0x30] + ldr x19, [x0, #0x58] + ldr x21, [x0, #0x8] + ptrue p0.b + KAI_ASM_INST(0x25207810) // ptrue pn8.b + KAI_ASM_INST(0x25b36571) // whilelt pn9.s, x11, x19, vlx4 + ld1rw { z30.s }, p0/Z, [x0, #0x18] + ld1rw { z31.s }, p0/Z, [x0, #0x1c] + ldr x18, [x0, #0x40] +KAI_ASM_LABEL(label_1) // Row Loop + ldr x17, [x0, #0x60] + ldr x26, [x0, #0x10] + mov x27, x16 + mov x22, #0 + KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, x15, vlx4 +KAI_ASM_LABEL(label_2) // Column Loop + mov x24, x26 + add x25, x26, x17 + KAI_ASM_INST(0x25396712) // whilelt pn10.b, x24, x25, vlx4 + addvl x28, x24, #4 + KAI_ASM_INST(0x25396793) // whilelt pn11.b, x28, x25, vlx4 + addvl x28, x28, #4 + KAI_ASM_INST(0x25396795) // whilelt pn13.b, x28, x25, vlx4 + addvl x28, x28, #4 + KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 + mov x23, #0 + whilelt p1.b, x23, x18 + KAI_ASM_INST(0xc00800ff) // 
zero {za} +KAI_ASM_LABEL(label_3) // Block Loop + ld1rqb { z0.b }, p1/z, [x21, x23] + KAI_ASM_INST(0xa0408b10) // ld1b { z16.b - z19.b }, pn10/z, [x24] + KAI_ASM_INST(0xa0418f14) // ld1b { z20.b - z23.b }, pn11/z, [x24, #0x4, mul vl] + KAI_ASM_INST(0xc150f220) // sdot za.s[w11, 0, vgx4], { z16.b - z19.b }, z0.b[0] + KAI_ASM_INST(0xc150f6a0) // sdot za.s[w11, 0, vgx4], { z20.b - z23.b }, z0.b[1] + KAI_ASM_INST(0xa0429710) // ld1b { z16.b - z19.b }, pn13/z, [x24, #0x8, mul vl] + KAI_ASM_INST(0xa0439b14) // ld1b { z20.b - z23.b }, pn14/z, [x24, #0xc, mul vl] + KAI_ASM_INST(0xc150fa20) // sdot za.s[w11, 0, vgx4], { z16.b - z19.b }, z0.b[2] + KAI_ASM_INST(0xc150fea0) // sdot za.s[w11, 0, vgx4], { z20.b - z23.b }, z0.b[3] + addvl x24, x24, #16 + KAI_ASM_INST(0x25396712) // whilelt pn10.b, x24, x25, vlx4 + addvl x28, x24, #4 + KAI_ASM_INST(0x25396793) // whilelt pn11.b, x28, x25, vlx4 + addvl x28, x28, #4 + KAI_ASM_INST(0x25396795) // whilelt pn13.b, x28, x25, vlx4 + addvl x28, x28, #4 + KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 + add x23, x23, #16 + whilelt p1.b, x23, x18 + b.first label_3 + add x28, x21, x18 + ld1rw { z2.s }, p0/z, [x28] + ld1rw { z3.s }, p0/z, [x28, #4] + add x28, x26, x17 + KAI_ASM_INST(0xa040c794) // ld1w { z20.s - z23.s }, pn9/z, [x28] + KAI_ASM_INST(0xa041c798) // ld1w { z24.s - z27.s }, pn9/z, [x28, #0x4, mul vl] + KAI_ASM_INST(0xa042c78c) // ld1w { z12.s - z15.s }, pn9/z, [x28, #0x8, mul vl] + KAI_ASM_INST(0xc0066c04) // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] + mla z4.s, p0/m, z20.s, z2.s + mla z5.s, p0/m, z21.s, z2.s + mla z6.s, p0/m, z22.s, z2.s + mla z7.s, p0/m, z23.s, z2.s + KAI_ASM_INST(0xc132e084) // scvtf { z4.s - z7.s }, { z4.s - z7.s } + fmul z24.s, z24.s, z3.s + fmul z25.s, z25.s, z3.s + fmul z26.s, z26.s, z3.s + fmul z27.s, z27.s, z3.s + fmla z12.s, p0/m, z24.s, z4.s + fmla z13.s, p0/m, z25.s, z5.s + fmla z14.s, p0/m, z26.s, z6.s + fmla z15.s, p0/m, z27.s, z7.s + KAI_ASM_INST(0xc1bfcbcc) // fclamp { z12.s - z15.s }, 
z30.s, z31.s + KAI_ASM_INST(0xa036d36c) // st1w { z12.s - z15.s }, pn12, [x27, x22, lsl #2] + ldr x20, [x0, #0x50] + add x26, x26, x20 + addvl x22, x22, #1 + KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, x15, vlx4 + b.lt label_2 + ldr x20, [x0, #0x20] + add x16, x16, x20 + ldr x20, [x0, #0x48] + add x21, x21, x20 + ldr x20, [x0, #0x68] + cmp x21, x20 + b.lt label_1 + KAI_ASM_INST(0xd503467f) // smstop + ldp d14, d15, [sp, 144] + ldp d12, d13, [sp, 128] + ldp d10, d11, [sp, 112] + ldp d8, d9, [sp, 96] + ldp x27, x28, [sp, 80] + ldp x25, x26, [sp, 64] + ldp x23, x24, [sp, 48] + ldp x21, x22, [sp, 32] + ldp x19, x20, [sp, 16] + ldp x29, x30, [sp], 176 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) + + KAI_ASM_END -- GitLab From b46ed7a9bdccd1a097b6c6ccb552c7cf4bc9e901 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 28 May 2025 08:03:21 +0100 Subject: [PATCH 07/14] Extract inline asm from GEMM kernel to .S file - To enable MSVC support move the inline assembly code from kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c to a pure assembly kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S - Minor fixes to kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c to improve readability Signed-off-by: Anitha Raj --- CMakeLists.txt | 3 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ...8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c | 238 ++++-------------- ...1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S | 175 +++++++++++++ ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 30 +-- ...8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S | 38 +-- 6 files changed, 260 insertions(+), 226 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index aff431f1..ef7d0dc6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -240,6 +240,8 @@ set(KLEIDIAI_FILES_SME ) 
set(KLEIDIAI_FILES_SME2_ASM + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S ) @@ -255,7 +257,6 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index b49cfb91..4382f370 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -159,6 +159,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS_ASM = [ + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", ] @@ -174,7 +175,6 @@ SME2_KERNELS = [ "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", 
"matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", - "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c index ccb3aef7..67964893 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c @@ -13,11 +13,28 @@ #include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h" -#include -#include - #include "kai/kai_common.h" +typedef struct { + float* dst; // 0x00 + const void* lhs_packed; // 0x08 + const void* rhs_packed; // 0x10 + size_t dst_stride_row; // 0x18 + size_t m; // 0x20 + size_t n; // 0x28 + size_t lhs_stride; // 0x30 + size_t rhs_stride; // 0x38 + size_t mr; // 0x40 + size_t nr; // 0x48 + size_t rhs_row_bytes; // 0x50 + size_t m_blk; // 0x58 + size_t dst_inc; // 0x60 + float clamp_min; // 0x68 + float clamp_max; // 0x6C +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(KernelArgs* args_ptr); + // Compute args static const size_t kai_m_step = 1; // multiple of vector length static const size_t kai_n_step = 4; // multiple of vector length @@ -126,198 +143,39 @@ 
size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_ } void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( - size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, - float* restrict dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { KAI_ASSERT(dst_stride_col == sizeof(float)); KAI_ASSERT(n > 0); KAI_ASSERT(m > 0); - const float scalar_bounds[2] = {scalar_min, scalar_max}; - - typedef struct { - size_t lhs_stride; - size_t rhs_stride; - size_t mr; - size_t nr; - size_t m_blk; - size_t dst_inc; - size_t rhs_row_bytes; - } KernelArgs; - - KernelArgs ka; - - // Constants - ka.mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); - ka.nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); - ka.lhs_stride = kai_get_lhs_packed_stride(k); - ka.rhs_stride = kai_get_rhs_packed_stride(k); + KernelArgs args; const size_t k_internal = kai_k_roundedup(k); - ka.m_blk = k_internal * ka.mr; - ka.dst_inc = ka.mr * dst_stride_row; - ka.rhs_row_bytes = ka.nr * k_internal; - - __asm__ volatile( - " .inst 0xd503477f //smstart \n" - " mov x19, %[dst] \n" - " mov x20, %[lhs] \n" - " cntw x7 \n" - - " ptrue p2.b \n" - " ld1rw {z30.s}, p2/Z, [%[scalar_bounds]] \n" - - " ld1rw {z31.s}, p2/Z, [%[scalar_bounds], #4] \n" - - // M loop head - " mov x12, %[m] \n" - " .inst 0x25ac17e0 //whilelt p0.s, xzr, x12 \n" - "1: \n" - " mov x8, %[rhs] \n" - " mov x9, x19 \n" - " mov x13, %[n] \n" - " cmp x7, x12 \n" - " csel x16, x7, x12, lt \n" - " lsl x16, x16, #2 \n" - - // N loop head - " ldr x24, 
[%x[args_ptr], %[offset_rhs_row_bytes]] \n" - " add x24, x24, x8 \n" - " mov x11, x8 \n" - " .inst 0x25784570 // whilelt pn8.h, x11, x24, vlx2 \n" - " addvl x11, x8, #2 \n" - " .inst 0x25784572 // whilelt pn10.h, x11, x24, vlx2 \n" - - "2: \n" - " mov x10, x20 \n" - " mov x11, x8 \n" - " mov x17, x9 \n" - " .inst 0x25ad67f1 //whilelt pn9.s, xzr, x13, vlx4 \n" - - // K loop - " .inst 0xc00800ff //zero {za} \n" - " ldr x24, [%x[args_ptr], %[offset_m_blk]] \n" - " add x14, x10, x24 \n" - - "3: \n" - " .inst 0xa540a144 // ld1w { z4.s }, p0/z, [x10] \n" - " .inst 0x042a502a // addvl x10, x10, #1 \n" - - " .inst 0xa0402168 //ld1h { z8.h - z9.h }, pn8/z, [x11] \n" - " .inst 0xa0884880 //smopa za0.s, p2/m, p2/m, z4.b, z8.b \n" - " .inst 0xa0894881 //smopa za1.s, p2/m, p2/m, z4.b, z9.b \n" - - " .inst 0xa041296a //ld1h { z10.h - z11.h }, pn8/z, [x11, #0x2, mul vl] \n" - " .inst 0xa08a4882 //smopa za2.s, p2/m, p2/m, z4.b, z10.b\n" - " .inst 0xa08b4883 //smopa za3.s, p2/m, p2/m, z4.b, z11.b\n" - - " .inst 0x042b508b // addvl x11, x11, #4 \n" - " cmp x10, x14 \n" - " b.lt 3b \n" - - // RHS row sum, scale factor & bias - " .inst 0xa040c560 //ld1w { z0.s-z3.s }, pn9/z, [x11] \n" - " .inst 0xa041c564 //ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] \n" - " .inst 0xa042c568 //ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl]\n" - " .inst 0x042b518b //addvl x11, x11, #12 \n" - " .inst 0xc132e000 //scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" - - // Store loop - " mov x14, #0 \n" - " addvl x15, x10, #1 \n" - - "4: \n" - // Load LHS Row-offset & SF - " ld1rw {z16.s}, p2/z, [x10] \n" - " ld1rw {z17.s}, p2/z, [x15] \n" - " add x10, x10, #4 \n" - " add x15, x15, #4 \n" - " scvtf z16.s, p2/m, z16.s \n" - - // offset x Row-sum - " fmul z24.s, z16.s, z0.s \n" - " fmul z25.s, z16.s, z1.s \n" - " fmul z26.s, z16.s, z2.s \n" - " fmul z27.s, z16.s, z3.s \n" - - // Scaling factors - " fmul z20.s, z17.s, z4.s \n" - " fmul z21.s, z17.s, z5.s \n" - " fmul z22.s, z17.s, z6.s \n" - " fmul z23.s, z17.s, z7.s \n" 
- - // Result = offset x Row-sum x SFs - " fmul z24.s, z24.s, z20.s \n" - " fmul z25.s, z25.s, z21.s \n" - " fmul z26.s, z26.s, z22.s \n" - " fmul z27.s, z27.s, z23.s \n" - - // Load inner accumulation & convert - " .inst 0xc006440c //mova { z12.b-z15.b }, za0h.b[w14, 0:3]\n" - " .inst 0xc132e18c //scvtf { z12.s-z15.s }, { z12.s-z15.s } \n" - - // Result += iacc x SF - " fmla z24.s, p2/m, z20.s, z12.s \n" - " fmla z25.s, p2/m, z21.s, z13.s \n" - " fmla z26.s, p2/m, z22.s, z14.s \n" - " fmla z27.s, p2/m, z23.s, z15.s \n" - - // Add the bias - " fadd z24.s, p2/m, z24.s, z8.s \n" - " fadd z25.s, p2/m, z25.s, z9.s \n" - " fadd z26.s, p2/m, z26.s, z10.s \n" - " fadd z27.s, p2/m, z27.s, z11.s \n" - - // CLAMP and store - " .inst 0xc1bfcbd8 //fclamp { z24.s-z27.s }, z30.s, z31.s\n" - " .inst 0xa060c638 //st1w { z24.s-z27.s }, pn9, [x17] \n" - - " add x17, x17, %[dst_stride_row] \n" - " add x14, x14, #4 \n" - " cmp x14, x16 \n" - " b.lt 4b \n" - - // N loop tail - " ldr x24, [%x[args_ptr], %[offset_rhs_stride]] \n" - " add x8, x8, x24 \n" - " .inst 0x04295089 // addvl x9, x9, #4 \n" - " ldr x24, [%x[args_ptr], %[offset_rhs_row_bytes]] \n" - " add x24, x24, x8 \n" - " mov x11, x8 \n" - " .inst 0x25784570 // whilelt pn8.h, x11, x24, vlx2\n" - " addvl x11, x8, #2 \n" - " .inst 0x25784572 // whilelt pn10.h, x11, x24, vlx2 \n" - " ldr x24, [%x[args_ptr], %[offset_nr]] \n" - " sub x13, x13, x24 \n" - " cmp xzr, x13 \n" - " b.mi 2b \n" - - // M loop tail - " ldr x24, [%x[args_ptr], %[offset_lhs_stride]] \n" - " add x20, x20, x24 \n" - " ldr x24, [%x[args_ptr], %[offset_dst_inc]] \n" - " add x19, x19, x24 \n" - " ldr x24, [%x[args_ptr], %[offset_mr]] \n" - " sub x12, x12, x24 \n" - " whilelt p0.s, xzr, x12 \n" - " b.mi 1b \n" - - "5: \n" - " .inst 0xd503467f //smstop \n" - : - : [m] "r"(m), [n] "r"(n), [k] "r"(k), [dst] "r"(dst), [lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), - [dst_stride_row] "r"(dst_stride_row), [scalar_bounds] "r"(scalar_bounds), [args_ptr] "r"(&ka), - 
[offset_m_blk] "I"(offsetof(KernelArgs, m_blk)), [offset_mr] "I"(offsetof(KernelArgs, mr)), - [offset_nr] "I"(offsetof(KernelArgs, nr)), [offset_dst_inc] "I"(offsetof(KernelArgs, dst_inc)), - [offset_lhs_stride] "I"(offsetof(KernelArgs, lhs_stride)), - [offset_rhs_stride] "I"(offsetof(KernelArgs, rhs_stride)), - [offset_rhs_row_bytes] "I"(offsetof(KernelArgs, rhs_row_bytes)) - : "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "x24", "p0", - "p2", "p8", "p9", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", - "z14", "z15", "z16", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z30", "z31", -#ifdef __ARM_STATE_ZA - "za", -#endif - "cc", "memory"); + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.lhs_stride = kai_get_lhs_packed_stride(k); + args.rhs_stride = kai_get_rhs_packed_stride(k); + args.mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + args.nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + args.rhs_row_bytes = args.nr * k_internal; + args.m_blk = args.mr * k_internal; + args.dst_inc = args.mr * dst_stride_row; + args.clamp_min = scalar_min; + args.clamp_max = scalar_max; + + kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(&args); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S new file mode 100644 index 00000000..b0ee303f --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S @@ -0,0 +1,175 @@ +// +// SPDX-FileCopyrightText: Copyright 
2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) + stp x29, x30, [sp, -176]! 
+ mov x29, sp + stp x19, x20, [sp, 16] + stp x21, x22, [sp, 32] + stp x23, x24, [sp, 48] + stp x25, x26, [sp, 64] + stp x27, x28, [sp, 80] + stp d8, d9, [sp, 96] + stp d10, d11, [sp, 112] + stp d12, d13, [sp, 128] + stp d14, d15, [sp, 144] + KAI_ASM_INST(0xd503477f) // smstart + ldr x19, [x0, #0] // dst + ldr x20, [x0, #8] // lhs_packed + cntw x7 + ptrue p2.b + ld1rw {z30.s}, p2/Z, [x0, #0x68] // clamp_min + ld1rw {z31.s}, p2/Z, [x0, #0x6C] // clamp_max + ldr x12, [x0, #0x20] // m + KAI_ASM_INST(0x25ac17e0) // whilelt p0.s, xzr, x12 +KAI_ASM_LABEL(label_1) // Row Loop + ldr x8, [x0, #0x10] // rhs_packed + mov x9, x19 + ldr x13, [x0, #0x28] // n + cmp x7, x12 + csel x16, x7, x12, lt + lsl x16, x16, #2 + ldr x24, [x0, #0x50] // rhs_row_bytes + add x24, x24, x8 + mov x11, x8 + KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 + addvl x11, x8, #2 + KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 +KAI_ASM_LABEL(label_2) // Column Loop + mov x10, x20 + mov x11, x8 + mov x17, x9 + KAI_ASM_INST(0x25ad67f1) // whilelt pn9.s, xzr, x13, vlx4 + KAI_ASM_INST(0xc00800ff) // zero {za} + ldr x24, [x0, #0x58] // m_blk + add x14, x10, x24 +KAI_ASM_LABEL(label_3) // Block Loop + KAI_ASM_INST(0xa540a144) // ld1w { z4.s }, p0/z, [x10] + KAI_ASM_INST(0x042a502a) // addvl x10, x10, #1 + KAI_ASM_INST(0xa0402168) // ld1h { z8.h - z9.h }, pn8/z, [x11] + KAI_ASM_INST(0xa0884880) // smopa za0.s, p2/m, p2/m, z4.b, z8.b + KAI_ASM_INST(0xa0894881) // smopa za1.s, p2/m, p2/m, z4.b, z9.b + KAI_ASM_INST(0xa041296a) // ld1h { z10.h - z11.h }, pn8/z, [x11, #0x2, mul vl] + KAI_ASM_INST(0xa08a4882) // smopa za2.s, p2/m, p2/m, z4.b, z10.b + KAI_ASM_INST(0xa08b4883) // smopa za3.s, p2/m, p2/m, z4.b, z11.b + KAI_ASM_INST(0x042b508b) // addvl x11, x11, #4 + cmp x10, x14 + b.lt label_3 + KAI_ASM_INST(0xa040c560) // ld1w { z0.s-z3.s }, pn9/z, [x11] + KAI_ASM_INST(0xa041c564) // ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] + KAI_ASM_INST(0xa042c568) // ld1w { z8.s-z11.s }, pn9/z, [x11, #8, 
mul vl] + KAI_ASM_INST(0x042b518b) // addvl x11, x11, #12 + KAI_ASM_INST(0xc132e000) // scvtf { z0.s-z3.s }, { z0.s-z3.s } + mov x14, #0 + addvl x15, x10, #1 +KAI_ASM_LABEL(label_4) + ld1rw {z16.s}, p2/z, [x10] + ld1rw {z17.s}, p2/z, [x15] + add x10, x10, #4 + add x15, x15, #4 + scvtf z16.s, p2/m, z16.s + fmul z24.s, z16.s, z0.s + fmul z25.s, z16.s, z1.s + fmul z26.s, z16.s, z2.s + fmul z27.s, z16.s, z3.s + fmul z20.s, z17.s, z4.s + fmul z21.s, z17.s, z5.s + fmul z22.s, z17.s, z6.s + fmul z23.s, z17.s, z7.s + fmul z24.s, z24.s, z20.s + fmul z25.s, z25.s, z21.s + fmul z26.s, z26.s, z22.s + fmul z27.s, z27.s, z23.s + KAI_ASM_INST(0xc006440c) // mova { z12.b-z15.b }, za0h.b[w14, 0:3] + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } + fmla z24.s, p2/m, z20.s, z12.s + fmla z25.s, p2/m, z21.s, z13.s + fmla z26.s, p2/m, z22.s, z14.s + fmla z27.s, p2/m, z23.s, z15.s + fadd z24.s, p2/m, z24.s, z8.s + fadd z25.s, p2/m, z25.s, z9.s + fadd z26.s, p2/m, z26.s, z10.s + fadd z27.s, p2/m, z27.s, z11.s + KAI_ASM_INST(0xc1bfcbd8) // fclamp { z24.s-z27.s }, z30.s, z31.s + KAI_ASM_INST(0xa060c638) // st1w { z24.s-z27.s }, pn9, [x17] + ldr x24, [x0, #0x18] // dst_stride_row + add x17, x17, x24 + add x14, x14, #4 + cmp x14, x16 + b.lt label_4 + ldr x24, [x0, #0x38] // rhs_stride + add x8, x8, x24 + KAI_ASM_INST(0x04295089) // addvl x9, x9, #4 + ldr x24, [x0, #0x50] // rhs_row_bytes + add x24, x24, x8 + mov x11, x8 + KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 + addvl x11, x8, #2 + KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 + ldr x24, [x0, #0x48] // nr + sub x13, x13, x24 + cmp xzr, x13 + b.mi label_2 + ldr x24, [x0, #0x30] // lhs_stride + add x20, x20, x24 + ldr x24, [x0, #0x60] // dst_inc + add x19, x19, x24 + ldr x24, [x0, #0x40] // mr + sub x12, x12, x24 + whilelt p0.s, xzr, x12 + b.mi label_1 + KAI_ASM_INST(0xd503467f) // smstop + ldp d14, d15, [sp, 144] + ldp d12, d13, [sp, 128] + ldp d10, d11, [sp, 112] + ldp d8, d9, [sp, 96] + ldp 
x27, x28, [sp, 80] + ldp x25, x26, [sp, 64] + ldp x23, x24, [sp, 48] + ldp x21, x22, [sp, 32] + ldp x19, x20, [sp, 16] + ldp x29, x30, [sp], 176 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index c264810e..8442a4c7 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -16,21 +16,21 @@ #include "kai/kai_common.h" typedef struct { - float* dst; - const void* lhs_packed; - const void* rhs_packed; - float clamp_min; - float clamp_max; - size_t dst_stride_row; - size_t m; - size_t n; - size_t k; - size_t k_internal; - size_t lhs_stride; - size_t rhs_stride; - size_t nr; - size_t rhs_row_bytes; - size_t lhs_end; + float* dst; // 0x00 + const void* lhs_packed; // 0x08 + const void* rhs_packed; // 0x10 + size_t dst_stride_row; // 0x18 + size_t m; // 0x20 + size_t n; // 0x28 + size_t k; // 0x30 + size_t k_internal; // 0x38 + size_t lhs_stride; // 0x40 + size_t rhs_stride; // 0x48 + size_t nr; // 0x50 + size_t rhs_row_bytes; // 0x58 + size_t lhs_end; // 0x60 + float clamp_min; // 0x68 + float clamp_max; // 0x6C } KernelArgs; void kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(KernelArgs* args_ptr); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S index bf98f92f..39f91da5 100644 --- 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S @@ -54,24 +54,24 @@ KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl stp d12, d13, [sp, 128] stp d14, d15, [sp, 144] KAI_ASM_INST(0xd503477f) // smstart - ldr x16, [x0, #0] + ldr x16, [x0, #0] // dst mov x11, #0 - ldr x15, [x0, #0x30] - ldr x19, [x0, #0x58] - ldr x21, [x0, #0x8] + ldr x15, [x0, #0x28] // n + ldr x19, [x0, #0x50] // nr + ldr x21, [x0, #0x8] // lhs_packed ptrue p0.b KAI_ASM_INST(0x25207810) // ptrue pn8.b KAI_ASM_INST(0x25b36571) // whilelt pn9.s, x11, x19, vlx4 - ld1rw { z30.s }, p0/Z, [x0, #0x18] - ld1rw { z31.s }, p0/Z, [x0, #0x1c] - ldr x18, [x0, #0x40] -KAI_ASM_LABEL(label_1) // Row Loop - ldr x17, [x0, #0x60] - ldr x26, [x0, #0x10] + ld1rw { z30.s }, p0/Z, [x0, #0x68] // clamp_min + ld1rw { z31.s }, p0/Z, [x0, #0x6c] // clamp_max + ldr x14, [x0, #0x38] // k_internal +KAI_ASM_LABEL(label_1) // Row Loop + ldr x17, [x0, #0x58] // rhs_row_bytes + ldr x26, [x0, #0x10] // rhs_packed mov x27, x16 mov x22, #0 KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, x15, vlx4 -KAI_ASM_LABEL(label_2) // Column Loop +KAI_ASM_LABEL(label_2) // Column Loop mov x24, x26 add x25, x26, x17 KAI_ASM_INST(0x25396712) // whilelt pn10.b, x24, x25, vlx4 @@ -82,9 +82,9 @@ KAI_ASM_LABEL(label_2) // Column Loop addvl x28, x28, #4 KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 mov x23, #0 - whilelt p1.b, x23, x18 + whilelt p1.b, x23, x14 KAI_ASM_INST(0xc00800ff) // zero {za} -KAI_ASM_LABEL(label_3) // Block Loop +KAI_ASM_LABEL(label_3) // Block Loop ld1rqb { z0.b }, p1/z, [x21, x23] KAI_ASM_INST(0xa0408b10) // ld1b { z16.b - z19.b }, pn10/z, [x24] KAI_ASM_INST(0xa0418f14) // ld1b { z20.b - z23.b }, pn11/z, [x24, #0x4, mul vl] @@ -103,9 +103,9 @@ KAI_ASM_LABEL(label_3) // Block Loop addvl 
x28, x28, #4 KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 add x23, x23, #16 - whilelt p1.b, x23, x18 + whilelt p1.b, x23, x14 b.first label_3 - add x28, x21, x18 + add x28, x21, x14 ld1rw { z2.s }, p0/z, [x28] ld1rw { z3.s }, p0/z, [x28, #4] add x28, x26, x17 @@ -128,16 +128,16 @@ KAI_ASM_LABEL(label_3) // Block Loop fmla z15.s, p0/m, z27.s, z7.s KAI_ASM_INST(0xc1bfcbcc) // fclamp { z12.s - z15.s }, z30.s, z31.s KAI_ASM_INST(0xa036d36c) // st1w { z12.s - z15.s }, pn12, [x27, x22, lsl #2] - ldr x20, [x0, #0x50] + ldr x20, [x0, #0x48] // rhs_stride add x26, x26, x20 addvl x22, x22, #1 KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, x15, vlx4 b.lt label_2 - ldr x20, [x0, #0x20] + ldr x20, [x0, #0x18] // dst_stride_row add x16, x16, x20 - ldr x20, [x0, #0x48] + ldr x20, [x0, #0x40] // lhs_stride add x21, x21, x20 - ldr x20, [x0, #0x68] + ldr x20, [x0, #0x60] // lhs_end cmp x21, x20 b.lt label_1 KAI_ASM_INST(0xd503467f) // smstop -- GitLab From 584056f94bca86d4b539be448f5541cc690d9375 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 28 May 2025 10:23:55 +0100 Subject: [PATCH 08/14] Fix Bazel build and clang tidy error Signed-off-by: Anitha Raj --- kai/ukernels/matmul/BUILD.bazel | 1 + ...2_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c | 2 ++ ...p_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 14 ++++++++------ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 4382f370..39ba214e 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -310,6 +310,7 @@ kai_c_library( ":neon_impl_asm", ":scalar_impl", ":sme2_impl", + ":sme2_impl_asm", ":sme_impl", ], ) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c index 67964893..dd5c21d9 100644 --- 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c @@ -13,6 +13,8 @@ #include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h" +#include + #include "kai/kai_common.h" typedef struct { diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index 8442a4c7..f62b63d3 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -13,6 +13,8 @@ #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" +#include + #include "kai/kai_common.h" typedef struct { @@ -157,13 +159,13 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( return; } - const uint64_t k_internal = kai_k_roundedup(k); - const uint64_t lhs_stride = kai_get_lhs_packed_stride(k); - const uint64_t rhs_stride = kai_get_rhs_packed_stride(k); - const uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + const size_t k_internal = kai_k_roundedup(k); + const size_t lhs_stride = kai_get_lhs_packed_stride(k); + const size_t rhs_stride = kai_get_rhs_packed_stride(k); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); - const uint64_t rhs_row_bytes = nr * k_internal; - const uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * lhs_stride); + const size_t rhs_row_bytes = nr * k_internal; + const size_t lhs_end_ptr = ((size_t)lhs_packed) + (m * lhs_stride); KernelArgs args; -- GitLab From 
00babf100897117063bd765273a86b95c81ff20d Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 28 May 2025 15:23:36 +0100 Subject: [PATCH 09/14] Calculate nr and mr in asm instead of passing in args Signed-off-by: Anitha Raj --- ...8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c | 23 +++++++++---------- ...1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S | 16 ++++++------- ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 10 ++++---- ...8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S | 10 ++++---- 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c index dd5c21d9..6cbc9e6c 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c @@ -26,13 +26,11 @@ typedef struct { size_t n; // 0x28 size_t lhs_stride; // 0x30 size_t rhs_stride; // 0x38 - size_t mr; // 0x40 - size_t nr; // 0x48 - size_t rhs_row_bytes; // 0x50 - size_t m_blk; // 0x58 - size_t dst_inc; // 0x60 - float clamp_min; // 0x68 - float clamp_max; // 0x6C + size_t rhs_row_bytes; // 0x40 + size_t m_blk; // 0x48 + size_t dst_inc; // 0x50 + float clamp_min; // 0x58 + float clamp_max; // 0x5C } KernelArgs; void kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(KernelArgs* args_ptr); @@ -159,6 +157,9 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( KAI_ASSERT(n > 0); KAI_ASSERT(m > 0); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + KernelArgs args; const size_t k_internal = kai_k_roundedup(k); 
args.dst = dst; @@ -169,11 +170,9 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( args.n = n; args.lhs_stride = kai_get_lhs_packed_stride(k); args.rhs_stride = kai_get_rhs_packed_stride(k); - args.mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); - args.nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); - args.rhs_row_bytes = args.nr * k_internal; - args.m_blk = args.mr * k_internal; - args.dst_inc = args.mr * dst_stride_row; + args.rhs_row_bytes = nr * k_internal; + args.m_blk = mr * k_internal; + args.dst_inc = mr * dst_stride_row; args.clamp_min = scalar_min; args.clamp_max = scalar_max; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S index b0ee303f..53e5e46e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S @@ -58,8 +58,8 @@ KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vl ldr x20, [x0, #8] // lhs_packed cntw x7 ptrue p2.b - ld1rw {z30.s}, p2/Z, [x0, #0x68] // clamp_min - ld1rw {z31.s}, p2/Z, [x0, #0x6C] // clamp_max + ld1rw {z30.s}, p2/Z, [x0, #0x58] // clamp_min + ld1rw {z31.s}, p2/Z, [x0, #0x5C] // clamp_max ldr x12, [x0, #0x20] // m KAI_ASM_INST(0x25ac17e0) // whilelt p0.s, xzr, x12 KAI_ASM_LABEL(label_1) // Row Loop @@ -69,7 +69,7 @@ KAI_ASM_LABEL(label_1) // Row Loop cmp x7, x12 csel x16, x7, x12, lt lsl x16, x16, #2 - ldr x24, [x0, #0x50] // rhs_row_bytes + ldr x24, [x0, #0x40] // rhs_row_bytes add x24, x24, x8 mov x11, x8 KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 @@ -81,7 +81,7 @@ 
KAI_ASM_LABEL(label_2) // Column Loop mov x17, x9 KAI_ASM_INST(0x25ad67f1) // whilelt pn9.s, xzr, x13, vlx4 KAI_ASM_INST(0xc00800ff) // zero {za} - ldr x24, [x0, #0x58] // m_blk + ldr x24, [x0, #0x48] // m_blk add x14, x10, x24 KAI_ASM_LABEL(label_3) // Block Loop KAI_ASM_INST(0xa540a144) // ld1w { z4.s }, p0/z, [x10] @@ -140,21 +140,21 @@ KAI_ASM_LABEL(label_4) ldr x24, [x0, #0x38] // rhs_stride add x8, x8, x24 KAI_ASM_INST(0x04295089) // addvl x9, x9, #4 - ldr x24, [x0, #0x50] // rhs_row_bytes + ldr x24, [x0, #0x40] // rhs_row_bytes add x24, x24, x8 mov x11, x8 KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 addvl x11, x8, #2 KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 - ldr x24, [x0, #0x48] // nr + cntb x24 // nr sub x13, x13, x24 cmp xzr, x13 b.mi label_2 ldr x24, [x0, #0x30] // lhs_stride add x20, x20, x24 - ldr x24, [x0, #0x60] // dst_inc + ldr x24, [x0, #0x50] // dst_inc add x19, x19, x24 - ldr x24, [x0, #0x40] // mr + cntw x24 // mr sub x12, x12, x24 whilelt p0.s, xzr, x12 b.mi label_1 diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index f62b63d3..56e711bd 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -28,11 +28,10 @@ typedef struct { size_t k_internal; // 0x38 size_t lhs_stride; // 0x40 size_t rhs_stride; // 0x48 - size_t nr; // 0x50 - size_t rhs_row_bytes; // 0x58 - size_t lhs_end; // 0x60 - float clamp_min; // 0x68 - float clamp_max; // 0x6C + size_t rhs_row_bytes; // 0x50 + size_t lhs_end; // 0x58 + float clamp_min; // 0x60 + float clamp_max; // 0x64 } KernelArgs; void 
kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(KernelArgs* args_ptr); @@ -181,7 +180,6 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( args.k_internal = k_internal; args.lhs_stride = lhs_stride; args.rhs_stride = rhs_stride; - args.nr = nr; args.rhs_row_bytes = rhs_row_bytes; args.lhs_end = lhs_end_ptr; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S index 39f91da5..603c6648 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S @@ -57,16 +57,16 @@ KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl ldr x16, [x0, #0] // dst mov x11, #0 ldr x15, [x0, #0x28] // n - ldr x19, [x0, #0x50] // nr + cntb x19 // nr ldr x21, [x0, #0x8] // lhs_packed ptrue p0.b KAI_ASM_INST(0x25207810) // ptrue pn8.b KAI_ASM_INST(0x25b36571) // whilelt pn9.s, x11, x19, vlx4 - ld1rw { z30.s }, p0/Z, [x0, #0x68] // clamp_min - ld1rw { z31.s }, p0/Z, [x0, #0x6c] // clamp_max + ld1rw { z30.s }, p0/Z, [x0, #0x60] // clamp_min + ld1rw { z31.s }, p0/Z, [x0, #0x64] // clamp_max ldr x14, [x0, #0x38] // k_internal KAI_ASM_LABEL(label_1) // Row Loop - ldr x17, [x0, #0x58] // rhs_row_bytes + ldr x17, [x0, #0x50] // rhs_row_bytes ldr x26, [x0, #0x10] // rhs_packed mov x27, x16 mov x22, #0 @@ -137,7 +137,7 @@ KAI_ASM_LABEL(label_3) // Block Loop add x16, x16, x20 ldr x20, [x0, #0x40] // lhs_stride add x21, x21, x20 - ldr x20, [x0, #0x60] // lhs_end + ldr x20, [x0, #0x58] // lhs_end cmp x21, x20 b.lt label_1 KAI_ASM_INST(0xd503467f) // smstop -- GitLab From 48f5bc54005bc0a5631a18b24de815d980adbd87 Mon Sep 17 00:00:00 2001 From: 
Anitha Raj Date: Wed, 28 May 2025 16:14:15 +0100 Subject: [PATCH 10/14] Address review comments, update nr calculation Signed-off-by: Anitha Raj --- ..._clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S | 2 +- ...tmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S index 53e5e46e..599e85b1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S @@ -146,7 +146,7 @@ KAI_ASM_LABEL(label_4) KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 addvl x11, x8, #2 KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 - cntb x24 // nr + cntw x24, ALL, MUL #4 // nr sub x13, x13, x24 cmp xzr, x13 b.mi label_2 diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S index 603c6648..4c8c3410 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S @@ -57,7 +57,7 @@ KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl ldr x16, [x0, #0] // dst mov x11, #0 ldr x15, [x0, #0x28] // n - cntb x19 // nr + cntw x19, ALL, MUL #4 // nr ldr x21, [x0, #0x8] // lhs_packed ptrue p0.b KAI_ASM_INST(0x25207810) // ptrue pn8.b -- 
GitLab From 03d1f0e2855b5e5e230168626b950dd45f3162b5 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 29 May 2025 17:08:25 +0100 Subject: [PATCH 11/14] Update kernel name and get parameter functions Signed-off-by: Anitha Raj --- CMakeLists.txt | 4 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ...dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c} | 56 +++++++++---------- ...dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h} | 22 ++++---- ...vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S} | 10 ++-- ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 4 +- .../matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp | 6 +- 7 files changed, 52 insertions(+), 52 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/{kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c => kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c} (71%) rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/{kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h => kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h} (84%) rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/{kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S => kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S} (94%) diff --git a/CMakeLists.txt b/CMakeLists.txt index ef7d0dc6..09b6769d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -240,8 +240,8 @@ set(KLEIDIAI_FILES_SME ) set(KLEIDIAI_FILES_SME2_ASM - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S 
kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S ) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 39ba214e..3dae80f4 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -159,7 +159,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS_ASM = [ - "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa", + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", ] diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c similarity index 71% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c index 6cbc9e6c..59c86c94 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c @@ -11,7 +11,7 @@ #error "This file must be compiled for AArch64, FEAT_SVE2" #else // Architectural features check. 
-#include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h" +#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" #include @@ -33,7 +33,7 @@ typedef struct { float clamp_max; // 0x5C } KernelArgs; -void kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(KernelArgs* args_ptr); +void kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 1; // multiple of vector length @@ -66,7 +66,7 @@ inline static size_t kai_k_roundedup(size_t k) { inline static size_t kai_get_lhs_packed_stride(size_t k) { const size_t k_internal = kai_k_roundedup(k); KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); - const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); size_t lhs_packed_stride = mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); // Since the LHS matrix is asymmetric with per-row quantization, we must include the // the number of bytes to hold the zero point value @@ -78,7 +78,7 @@ inline static size_t kai_get_lhs_packed_stride(size_t k) { inline static size_t kai_get_rhs_packed_stride(size_t k) { const size_t k_internal = kai_k_roundedup(k); KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); - const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); size_t rhs_packed_stride = nr * (k_internal * kai_num_bytes_qvalue_rhs); rhs_packed_stride += nr * kai_num_bytes_multiplier_rhs; // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include @@ -90,59 +90,59 @@ inline static size_t kai_get_rhs_packed_stride(size_t k) { return rhs_packed_stride; } -size_t 
kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { - return kai_m_step * kai_get_sme_vector_length_u32(); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_m_step * kai_get_sme_vector_length_u8() / kai_kr; } -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { - return kai_n_step * kai_get_sme_vector_length_u32(); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_n_step * kai_get_sme_vector_length_u8() / kai_kr; } -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { - return kai_mr * kai_get_sme_vector_length_u32(); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u8() / kai_kr; } -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { - return kai_nr * kai_get_sme_vector_length_u32(); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % 
kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()) == 0); - const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); return (m_idx / mr) * kai_get_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) { - KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()) == 0); - const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); return (n_idx / nr) * kai_get_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); - KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); + KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()) == 0); + KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()) == 0); return ((n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride); } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) { return (m * n * 
kai_num_bytes_dst_value); } -void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( +void kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( size_t m, // size_t n, // size_t k, // @@ -157,8 +157,8 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( KAI_ASSERT(n > 0); KAI_ASSERT(m > 0); - const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); - const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); KernelArgs args; const size_t k_internal = kai_k_roundedup(k); @@ -176,7 +176,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( args.clamp_min = scalar_min; args.clamp_max = scalar_max; - kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(&args); + kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(&args); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h similarity index 84% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h index bf04a4f1..46cc11f0 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h @@ -25,34 +25,34 @@ 
extern "C" { /// be processed must be a multiple of m step. /// /// @return the m step value -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void); /// Gets the n step value. /// The micro-kernel can process any N values. However, the starting N index to /// be processed must be a multiple of n step. /// /// @return the n step -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void); /// Gets the mr value, which must be used to pack the LHS matrix /// /// @return the mr value -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void); /// Gets the nr value, which must be used to pack the RHS matrix. /// /// @return the nr value -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void); /// Gets the kr value, which must be used to pack the LHS and RHS matrices /// /// @return the kr value -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void); /// Gets the sr value, which must be used to pack the LHS and RHS matrices /// /// @return the sr value -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(void); /// Gets the offset in bytes for the packed LHS matrix, /// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. 
@@ -63,7 +63,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(v /// @param[in] k Total number of columns in the LHS matrix (not packed). /// /// @return the offset in bytes to the packed LHS matrix -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(size_t m_idx, size_t k); /// Gets the offset in bytes for the packed RHS matrix, /// which contains the packed Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) values. @@ -72,7 +72,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx /// @param[in] k The common dimension between the LHS and RHS matrix (K). /// /// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(size_t n_idx, size_t k); /// Gets the offset in bytes for the DST matrix /// @@ -81,7 +81,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx /// @param[in] dst_stride The number of bytes in in each row of the DST matrix /// /// @return the destination(DST) offset in bytes -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( size_t m_idx, size_t n_idx, size_t dst_stride); /// Gets the size in bytes for the destination (DST) matrix. @@ -90,7 +90,7 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme /// @param[in] n Number of columns in the destination (DST) matrix. 
/// /// @return the destination (DST) matrix size in bytes -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n); /// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. /// @@ -112,7 +112,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_ /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). /// @param[in] scalar_min Min value used to clamp the final result. /// @param[in] scalar_max Max value used to clamp the final result. -void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa( +void kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( size_t m, // size_t n, // size_t k, // diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S similarity index 94% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S index 599e85b1..32220184 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -35,13 +35,13 @@ #define KAI_ASM_END #endif - KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) KAI_ASM_ALIGN - 
KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) -KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) -KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) stp x29, x30, [sp, -176]! mov x29, sp stp x19, x20, [sp, 16] @@ -170,6 +170,6 @@ KAI_ASM_LABEL(label_4) ldp x19, x20, [sp, 16] ldp x29, x30, [sp], 176 ret - KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa) + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index 56e711bd..a5e616bd 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -98,7 +98,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(v } size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { - return kai_n_step * kai_get_sme_vector_length_u32(); + return kai_n_step * kai_get_sme_vector_length_u8() / kai_kr; } size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { @@ -106,7 +106,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) } size_t 
kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { - return kai_nr * kai_get_sme_vector_length_u32(); + return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; } size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index af7181db..d5618a4c 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -15,7 +15,7 @@ #include #include -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi8cxp4vlx8_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" @@ -50,8 +50,8 @@ static const std::array {}; -- GitLab From 412e71feacd410dd253ad794675cb850701a33db Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 29 May 2025 18:04:46 +0100 Subject: [PATCH 12/14] Refactor the register backup in assembly Signed-off-by: Anitha Raj --- ...1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S | 33 +++++++---------- ...8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S | 35 ++++++++----------- 2 files changed, 26 insertions(+), 42 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S index 32220184..3bc2350f 100644 --- 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -42,17 +42,12 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) - stp x29, x30, [sp, -176]! - mov x29, sp - stp x19, x20, [sp, 16] - stp x21, x22, [sp, 32] - stp x23, x24, [sp, 48] - stp x25, x26, [sp, 64] - stp x27, x28, [sp, 80] - stp d8, d9, [sp, 96] - stp d10, d11, [sp, 112] - stp d12, d13, [sp, 128] - stp d14, d15, [sp, 144] + stp x19, x20, [sp, -96]! + stp x24, x25, [sp, 16] + stp d8, d9, [sp, 32] + stp d10, d11, [sp, 48] + stp d12, d13, [sp, 64] + stp d14, d15, [sp, 80] KAI_ASM_INST(0xd503477f) // smstart ldr x19, [x0, #0] // dst ldr x20, [x0, #8] // lhs_packed @@ -159,16 +154,12 @@ KAI_ASM_LABEL(label_4) whilelt p0.s, xzr, x12 b.mi label_1 KAI_ASM_INST(0xd503467f) // smstop - ldp d14, d15, [sp, 144] - ldp d12, d13, [sp, 128] - ldp d10, d11, [sp, 112] - ldp d8, d9, [sp, 96] - ldp x27, x28, [sp, 80] - ldp x25, x26, [sp, 64] - ldp x23, x24, [sp, 48] - ldp x21, x22, [sp, 32] - ldp x19, x20, [sp, 16] - ldp x29, x30, [sp], 176 + ldp d14, d15, [sp, 80] + ldp d12, d13, [sp, 64] + ldp d10, d11, [sp, 48] + ldp d8, d9, [sp, 32] + ldp x24, x25, [sp, 16] + ldp x19, x20, [sp], 96 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S index 4c8c3410..aa59c7b2 100644 --- 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S @@ -42,17 +42,13 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) - stp x29, x30, [sp, -176]! - mov x29, sp - stp x19, x20, [sp, 16] - stp x21, x22, [sp, 32] - stp x23, x24, [sp, 48] - stp x25, x26, [sp, 64] - stp x27, x28, [sp, 80] - stp d8, d9, [sp, 96] - stp d10, d11, [sp, 112] - stp d12, d13, [sp, 128] - stp d14, d15, [sp, 144] + stp x19, x20, [sp, -112]! + stp x21, x22, [sp, 16] + stp x23, x24, [sp, 32] + stp x25, x26, [sp, 48] + stp x27, x28, [sp, 64] + stp d12, d13, [sp, 80] + stp d14, d15, [sp, 96] KAI_ASM_INST(0xd503477f) // smstart ldr x16, [x0, #0] // dst mov x11, #0 @@ -141,16 +137,13 @@ KAI_ASM_LABEL(label_3) // Block Loop cmp x21, x20 b.lt label_1 KAI_ASM_INST(0xd503467f) // smstop - ldp d14, d15, [sp, 144] - ldp d12, d13, [sp, 128] - ldp d10, d11, [sp, 112] - ldp d8, d9, [sp, 96] - ldp x27, x28, [sp, 80] - ldp x25, x26, [sp, 64] - ldp x23, x24, [sp, 48] - ldp x21, x22, [sp, 32] - ldp x19, x20, [sp, 16] - ldp x29, x30, [sp], 176 + ldp d14, d15, [sp, 96] + ldp d12, d13, [sp, 80] + ldp x27, x28, [sp, 64] + ldp x25, x26, [sp, 48] + ldp x23, x24, [sp, 32] + ldp x21, x22, [sp, 16] + ldp x19, x20, [sp], 112 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) -- GitLab From bfcc3ec7da3491ecfeba100056d28f20518d0683 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 30 May 2025 11:58:07 +0100 Subject: [PATCH 13/14] Refactor Asm kernels Signed-off-by: Anitha Raj --- ...8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c | 6 +- ...1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S | 210 +++++++++--------- ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c | 4 +- 
...8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S | 176 +++++++-------- 4 files changed, 198 insertions(+), 198 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c index 59c86c94..c7fee164 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c @@ -18,8 +18,8 @@ #include "kai/kai_common.h" typedef struct { - float* dst; // 0x00 - const void* lhs_packed; // 0x08 + float* dst; // 0 + const void* lhs_packed; // 0x8 const void* rhs_packed; // 0x10 size_t dst_stride_row; // 0x18 size_t m; // 0x20 @@ -30,7 +30,7 @@ typedef struct { size_t m_blk; // 0x48 size_t dst_inc; // 0x50 float clamp_min; // 0x58 - float clamp_max; // 0x5C + float clamp_max; // 0x5c } KernelArgs; void kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(KernelArgs* args_ptr); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S index 3bc2350f..68891f76 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -48,112 +48,112 @@ KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vl stp d10, d11, [sp, 48] stp d12, d13, [sp, 64] stp d14, d15, [sp, 80] - KAI_ASM_INST(0xd503477f) // smstart - ldr x19, [x0, #0] // dst - ldr x20, [x0, #8] 
// lhs_packed - cntw x7 - ptrue p2.b - ld1rw {z30.s}, p2/Z, [x0, #0x58] // clamp_min - ld1rw {z31.s}, p2/Z, [x0, #0x5C] // clamp_max - ldr x12, [x0, #0x20] // m - KAI_ASM_INST(0x25ac17e0) // whilelt p0.s, xzr, x12 -KAI_ASM_LABEL(label_1) // Row Loop - ldr x8, [x0, #0x10] // rhs_packed - mov x9, x19 - ldr x13, [x0, #0x28] // n - cmp x7, x12 - csel x16, x7, x12, lt - lsl x16, x16, #2 - ldr x24, [x0, #0x40] // rhs_row_bytes - add x24, x24, x8 - mov x11, x8 - KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 - addvl x11, x8, #2 - KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 -KAI_ASM_LABEL(label_2) // Column Loop - mov x10, x20 - mov x11, x8 - mov x17, x9 - KAI_ASM_INST(0x25ad67f1) // whilelt pn9.s, xzr, x13, vlx4 - KAI_ASM_INST(0xc00800ff) // zero {za} - ldr x24, [x0, #0x48] // m_blk - add x14, x10, x24 -KAI_ASM_LABEL(label_3) // Block Loop - KAI_ASM_INST(0xa540a144) // ld1w { z4.s }, p0/z, [x10] - KAI_ASM_INST(0x042a502a) // addvl x10, x10, #1 - KAI_ASM_INST(0xa0402168) // ld1h { z8.h - z9.h }, pn8/z, [x11] - KAI_ASM_INST(0xa0884880) // smopa za0.s, p2/m, p2/m, z4.b, z8.b - KAI_ASM_INST(0xa0894881) // smopa za1.s, p2/m, p2/m, z4.b, z9.b - KAI_ASM_INST(0xa041296a) // ld1h { z10.h - z11.h }, pn8/z, [x11, #0x2, mul vl] - KAI_ASM_INST(0xa08a4882) // smopa za2.s, p2/m, p2/m, z4.b, z10.b - KAI_ASM_INST(0xa08b4883) // smopa za3.s, p2/m, p2/m, z4.b, z11.b - KAI_ASM_INST(0x042b508b) // addvl x11, x11, #4 - cmp x10, x14 - b.lt label_3 - KAI_ASM_INST(0xa040c560) // ld1w { z0.s-z3.s }, pn9/z, [x11] - KAI_ASM_INST(0xa041c564) // ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] - KAI_ASM_INST(0xa042c568) // ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl] - KAI_ASM_INST(0x042b518b) // addvl x11, x11, #12 - KAI_ASM_INST(0xc132e000) // scvtf { z0.s-z3.s }, { z0.s-z3.s } - mov x14, #0 - addvl x15, x10, #1 + KAI_ASM_INST(0xd503477f) // smstart + ldr x19, [x0] // dst + ldr x20, [x0, #0x8] // lhs_packed + cntw x7 + ptrue p2.b + ld1rw { z30.s }, p2/z, [x0, #0x58] // 
clamp_min + ld1rw { z31.s }, p2/z, [x0, #0x5c] // clamp_max + ldr x12, [x0, #0x20] // m + KAI_ASM_INST(0x25ac17e0) // whilelt p0.s, xzr, x12 +KAI_ASM_LABEL(label_1) // Row Loop + ldr x8, [x0, #0x10] // rhs_packed + mov x9, x19 + ldr x13, [x0, #0x28] // n + cmp x7, x12 + csel x16, x7, x12, lt + lsl x16, x16, #2 + ldr x24, [x0, #0x40] // rhs_row_bytes + add x24, x24, x8 + mov x11, x8 + KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 + addvl x11, x8, #0x2 + KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 +KAI_ASM_LABEL(label_2) // Column Loop + mov x10, x20 + mov x11, x8 + mov x17, x9 + KAI_ASM_INST(0x25ad67f1) // whilelt pn9.s, xzr, x13, vlx4 + KAI_ASM_INST(0xc00800ff) // zero {za} + ldr x24, [x0, #0x48] // m_blk + add x14, x10, x24 +KAI_ASM_LABEL(label_3) // Block Loop + ld1w { z4.s }, p0/z, [x10] + addvl x10, x10, #0x1 + KAI_ASM_INST(0xa0402168) // ld1h { z8.h, z9.h }, pn8/z, [x11] + KAI_ASM_INST(0xa0884880) // smopa za0.s, p2/m, p2/m, z4.b, z8.b + KAI_ASM_INST(0xa0894881) // smopa za1.s, p2/m, p2/m, z4.b, z9.b + KAI_ASM_INST(0xa041296a) // ld1h { z10.h, z11.h }, pn10/z, [x11, #0x2, mul vl] + KAI_ASM_INST(0xa08a4882) // smopa za2.s, p2/m, p2/m, z4.b, z10.b + KAI_ASM_INST(0xa08b4883) // smopa za3.s, p2/m, p2/m, z4.b, z11.b + addvl x11, x11, #0x4 + cmp x10, x14 + b.lt label_3 + KAI_ASM_INST(0xa040c560) // ld1w { z0.s - z3.s }, pn9/z, [x11] + KAI_ASM_INST(0xa041c564) // ld1w { z4.s - z7.s }, pn9/z, [x11, #0x4, mul vl] + KAI_ASM_INST(0xa042c568) // ld1w { z8.s - z11.s }, pn9/z, [x11, #0x8, mul vl] + addvl x11, x11, #0xc + KAI_ASM_INST(0xc132e000) // scvtf { z0.s - z3.s }, { z0.s - z3.s } + mov x14, #0x0 // =0 + addvl x15, x10, #0x1 KAI_ASM_LABEL(label_4) - ld1rw {z16.s}, p2/z, [x10] - ld1rw {z17.s}, p2/z, [x15] - add x10, x10, #4 - add x15, x15, #4 - scvtf z16.s, p2/m, z16.s - fmul z24.s, z16.s, z0.s - fmul z25.s, z16.s, z1.s - fmul z26.s, z16.s, z2.s - fmul z27.s, z16.s, z3.s - fmul z20.s, z17.s, z4.s - fmul z21.s, z17.s, z5.s - fmul z22.s, z17.s, 
z6.s - fmul z23.s, z17.s, z7.s - fmul z24.s, z24.s, z20.s - fmul z25.s, z25.s, z21.s - fmul z26.s, z26.s, z22.s - fmul z27.s, z27.s, z23.s - KAI_ASM_INST(0xc006440c) // mova { z12.b-z15.b }, za0h.b[w14, 0:3] - KAI_ASM_INST(0xc132e18c) // scvtf { z12.s-z15.s }, { z12.s-z15.s } - fmla z24.s, p2/m, z20.s, z12.s - fmla z25.s, p2/m, z21.s, z13.s - fmla z26.s, p2/m, z22.s, z14.s - fmla z27.s, p2/m, z23.s, z15.s - fadd z24.s, p2/m, z24.s, z8.s - fadd z25.s, p2/m, z25.s, z9.s - fadd z26.s, p2/m, z26.s, z10.s - fadd z27.s, p2/m, z27.s, z11.s - KAI_ASM_INST(0xc1bfcbd8) // fclamp { z24.s-z27.s }, z30.s, z31.s - KAI_ASM_INST(0xa060c638) // st1w { z24.s-z27.s }, pn9, [x17] - ldr x24, [x0, #0x18] // dst_stride_row - add x17, x17, x24 - add x14, x14, #4 - cmp x14, x16 - b.lt label_4 - ldr x24, [x0, #0x38] // rhs_stride - add x8, x8, x24 - KAI_ASM_INST(0x04295089) // addvl x9, x9, #4 - ldr x24, [x0, #0x40] // rhs_row_bytes - add x24, x24, x8 - mov x11, x8 - KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 - addvl x11, x8, #2 - KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 - cntw x24, ALL, MUL #4 // nr - sub x13, x13, x24 - cmp xzr, x13 - b.mi label_2 - ldr x24, [x0, #0x30] // lhs_stride - add x20, x20, x24 - ldr x24, [x0, #0x50] // dst_inc - add x19, x19, x24 - cntw x24 // mr - sub x12, x12, x24 - whilelt p0.s, xzr, x12 - b.mi label_1 - KAI_ASM_INST(0xd503467f) // smstop + ld1rw { z16.s }, p2/z, [x10] + ld1rw { z17.s }, p2/z, [x15] + add x10, x10, #0x4 + add x15, x15, #0x4 + scvtf z16.s, p2/m, z16.s + fmul z24.s, z16.s, z0.s + fmul z25.s, z16.s, z1.s + fmul z26.s, z16.s, z2.s + fmul z27.s, z16.s, z3.s + fmul z20.s, z17.s, z4.s + fmul z21.s, z17.s, z5.s + fmul z22.s, z17.s, z6.s + fmul z23.s, z17.s, z7.s + fmul z24.s, z24.s, z20.s + fmul z25.s, z25.s, z21.s + fmul z26.s, z26.s, z22.s + fmul z27.s, z27.s, z23.s + KAI_ASM_INST(0xc006440c) // mov { z12.b - z15.b }, za0h.b[w14, 0x0:0x3] + KAI_ASM_INST(0xc132e18c) // scvtf { z12.s - z15.s }, { z12.s - z15.s } + 
fmla z24.s, p2/m, z20.s, z12.s + fmla z25.s, p2/m, z21.s, z13.s + fmla z26.s, p2/m, z22.s, z14.s + fmla z27.s, p2/m, z23.s, z15.s + fadd z24.s, p2/m, z24.s, z8.s + fadd z25.s, p2/m, z25.s, z9.s + fadd z26.s, p2/m, z26.s, z10.s + fadd z27.s, p2/m, z27.s, z11.s + KAI_ASM_INST(0xc1bfcbd8) // fclamp { z24.s - z27.s }, z30.s, z31.s + KAI_ASM_INST(0xa060c638) // st1w { z24.s - z27.s }, pn9, [x17] + ldr x24, [x0, #0x18] // dst_stride_row + add x17, x17, x24 + add x14, x14, #0x4 + cmp x14, x16 + b.lt label_4 + ldr x24, [x0, #0x38] // rhs_stride + add x8, x8, x24 + addvl x9, x9, #0x4 + ldr x24, [x0, #0x40] // rhs_row_bytes + add x24, x24, x8 + mov x11, x8 + KAI_ASM_INST(0x25784570) // whilelt pn8.h, x11, x24, vlx2 + addvl x11, x8, #0x2 + KAI_ASM_INST(0x25784572) // whilelt pn10.h, x11, x24, vlx2 + cntw x24, ALL, MUL #0x4 + sub x13, x13, x24 + cmp xzr, x13 + b.mi label_2 + ldr x24, [x0, #0x30] + add x20, x20, x24 + ldr x24, [x0, #0x50] + add x19, x19, x24 + cntw x24 + sub x12, x12, x24 + whilelt p0.s, xzr, x12 + b.mi label_1 + KAI_ASM_INST(0xd503467f) // smstop ldp d14, d15, [sp, 80] ldp d12, d13, [sp, 64] ldp d10, d11, [sp, 48] diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c index a5e616bd..be115774 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c @@ -18,8 +18,8 @@ #include "kai/kai_common.h" typedef struct { - float* dst; // 0x00 - const void* lhs_packed; // 0x08 + float* dst; // 0 + const void* lhs_packed; // 0x8 const void* rhs_packed; // 0x10 size_t dst_stride_row; // 0x18 size_t m; // 0x20 diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S index aa59c7b2..fabb060b 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S @@ -49,94 +49,94 @@ KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl stp x27, x28, [sp, 64] stp d12, d13, [sp, 80] stp d14, d15, [sp, 96] - KAI_ASM_INST(0xd503477f) // smstart - ldr x16, [x0, #0] // dst - mov x11, #0 - ldr x15, [x0, #0x28] // n - cntw x19, ALL, MUL #4 // nr - ldr x21, [x0, #0x8] // lhs_packed - ptrue p0.b - KAI_ASM_INST(0x25207810) // ptrue pn8.b - KAI_ASM_INST(0x25b36571) // whilelt pn9.s, x11, x19, vlx4 - ld1rw { z30.s }, p0/Z, [x0, #0x60] // clamp_min - ld1rw { z31.s }, p0/Z, [x0, #0x64] // clamp_max - ldr x14, [x0, #0x38] // k_internal -KAI_ASM_LABEL(label_1) // Row Loop - ldr x17, [x0, #0x50] // rhs_row_bytes - ldr x26, [x0, #0x10] // rhs_packed - mov x27, x16 - mov x22, #0 - KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, x15, vlx4 -KAI_ASM_LABEL(label_2) // Column Loop - mov x24, x26 - add x25, x26, x17 - KAI_ASM_INST(0x25396712) // whilelt pn10.b, x24, x25, vlx4 - addvl x28, x24, #4 - KAI_ASM_INST(0x25396793) // whilelt pn11.b, x28, x25, vlx4 - addvl x28, x28, #4 - KAI_ASM_INST(0x25396795) // whilelt pn13.b, x28, x25, vlx4 - addvl x28, x28, #4 - KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 - mov x23, #0 - whilelt p1.b, x23, x14 - KAI_ASM_INST(0xc00800ff) // zero {za} -KAI_ASM_LABEL(label_3) // Block Loop - ld1rqb { z0.b }, p1/z, [x21, x23] - KAI_ASM_INST(0xa0408b10) // ld1b { z16.b - z19.b }, pn10/z, [x24] - KAI_ASM_INST(0xa0418f14) // ld1b { z20.b - z23.b }, pn11/z, [x24, 
#0x4, mul vl] - KAI_ASM_INST(0xc150f220) // sdot za.s[w11, 0, vgx4], { z16.b - z19.b }, z0.b[0] - KAI_ASM_INST(0xc150f6a0) // sdot za.s[w11, 0, vgx4], { z20.b - z23.b }, z0.b[1] - KAI_ASM_INST(0xa0429710) // ld1b { z16.b - z19.b }, pn13/z, [x24, #0x8, mul vl] - KAI_ASM_INST(0xa0439b14) // ld1b { z20.b - z23.b }, pn14/z, [x24, #0xc, mul vl] - KAI_ASM_INST(0xc150fa20) // sdot za.s[w11, 0, vgx4], { z16.b - z19.b }, z0.b[2] - KAI_ASM_INST(0xc150fea0) // sdot za.s[w11, 0, vgx4], { z20.b - z23.b }, z0.b[3] - addvl x24, x24, #16 - KAI_ASM_INST(0x25396712) // whilelt pn10.b, x24, x25, vlx4 - addvl x28, x24, #4 - KAI_ASM_INST(0x25396793) // whilelt pn11.b, x28, x25, vlx4 - addvl x28, x28, #4 - KAI_ASM_INST(0x25396795) // whilelt pn13.b, x28, x25, vlx4 - addvl x28, x28, #4 - KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 - add x23, x23, #16 - whilelt p1.b, x23, x14 - b.first label_3 - add x28, x21, x14 - ld1rw { z2.s }, p0/z, [x28] - ld1rw { z3.s }, p0/z, [x28, #4] - add x28, x26, x17 - KAI_ASM_INST(0xa040c794) // ld1w { z20.s - z23.s }, pn9/z, [x28] - KAI_ASM_INST(0xa041c798) // ld1w { z24.s - z27.s }, pn9/z, [x28, #0x4, mul vl] - KAI_ASM_INST(0xa042c78c) // ld1w { z12.s - z15.s }, pn9/z, [x28, #0x8, mul vl] - KAI_ASM_INST(0xc0066c04) // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] - mla z4.s, p0/m, z20.s, z2.s - mla z5.s, p0/m, z21.s, z2.s - mla z6.s, p0/m, z22.s, z2.s - mla z7.s, p0/m, z23.s, z2.s - KAI_ASM_INST(0xc132e084) // scvtf { z4.s - z7.s }, { z4.s - z7.s } - fmul z24.s, z24.s, z3.s - fmul z25.s, z25.s, z3.s - fmul z26.s, z26.s, z3.s - fmul z27.s, z27.s, z3.s - fmla z12.s, p0/m, z24.s, z4.s - fmla z13.s, p0/m, z25.s, z5.s - fmla z14.s, p0/m, z26.s, z6.s - fmla z15.s, p0/m, z27.s, z7.s - KAI_ASM_INST(0xc1bfcbcc) // fclamp { z12.s - z15.s }, z30.s, z31.s - KAI_ASM_INST(0xa036d36c) // st1w { z12.s - z15.s }, pn12, [x27, x22, lsl #2] - ldr x20, [x0, #0x48] // rhs_stride - add x26, x26, x20 - addvl x22, x22, #1 - KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, 
x15, vlx4 - b.lt label_2 - ldr x20, [x0, #0x18] // dst_stride_row - add x16, x16, x20 - ldr x20, [x0, #0x40] // lhs_stride - add x21, x21, x20 - ldr x20, [x0, #0x58] // lhs_end - cmp x21, x20 - b.lt label_1 - KAI_ASM_INST(0xd503467f) // smstop + KAI_ASM_INST(0xd503477f) // smstart + ldr x16, [x0] // dst + mov x11, #0x0 // =0 + ldr x15, [x0, #0x28] // n + cntw x19, ALL, MUL #4 // nr + ldr x21, [x0, #0x8] // lhs_packed + ptrue p0.b + KAI_ASM_INST(0x25207810) // ptrue pn8.b + KAI_ASM_INST(0x25b36571) // whilelt pn9.s, x11, x19, vlx4 + ld1rw { z30.s }, p0/z, [x0, #0x60] // clamp_min + ld1rw { z31.s }, p0/z, [x0, #0x64] // clamp_max + ldr x14, [x0, #0x38] // k_internal +KAI_ASM_LABEL(label_1) // Row Loop + ldr x17, [x0, #0x50] // rhs_row_bytes + ldr x26, [x0, #0x10] // rhs_packed + mov x27, x16 + mov x22, #0x0 // =0 + KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, x15, vlx4 +KAI_ASM_LABEL(label_2) // Column Loop + mov x24, x26 + add x25, x26, x17 + KAI_ASM_INST(0x25396712) // whilelt pn10.b, x24, x25, vlx4 + addvl x28, x24, #0x4 + KAI_ASM_INST(0x25396793) // whilelt pn11.b, x28, x25, vlx4 + addvl x28, x28, #0x4 + KAI_ASM_INST(0x25396795) // whilelt pn13.b, x28, x25, vlx4 + addvl x28, x28, #0x4 + KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 + mov x23, #0x0 // =0 + whilelt p1.b, x23, x14 + KAI_ASM_INST(0xc00800ff) // zero {za} +KAI_ASM_LABEL(label_3) // Block Loop + KAI_ASM_INST(0xa41706a0) // ld1rqb { z0.b }, p1/z, [x21, x23] + KAI_ASM_INST(0xa0408b10) // ld1b { z16.b - z19.b }, pn10/z, [x24] + KAI_ASM_INST(0xa0418f14) // ld1b { z20.b - z23.b }, pn11/z, [x24, #0x4, mul vl] + KAI_ASM_INST(0xc150f220) // sdot za.s[w11, 0, vgx4], { z16.b - z19.b }, z0.b[0] + KAI_ASM_INST(0xc150f6a0) // sdot za.s[w11, 0, vgx4], { z20.b - z23.b }, z0.b[1] + KAI_ASM_INST(0xa0429710) // ld1b { z16.b - z19.b }, pn13/z, [x24, #0x8, mul vl] + KAI_ASM_INST(0xa0439b14) // ld1b { z20.b - z23.b }, pn14/z, [x24, #0xc, mul vl] + KAI_ASM_INST(0xc150fa20) // sdot za.s[w11, 0, vgx4], { 
z16.b - z19.b }, z0.b[2] + KAI_ASM_INST(0xc150fea0) // sdot za.s[w11, 0, vgx4], { z20.b - z23.b }, z0.b[3] + addvl x24, x24, #0x10 + KAI_ASM_INST(0x25396712) // whilelt pn10.b, x24, x25, vlx4 + addvl x28, x24, #0x4 + KAI_ASM_INST(0x25396793) // whilelt pn11.b, x28, x25, vlx4 + addvl x28, x28, #0x4 + KAI_ASM_INST(0x25396795) // whilelt pn13.b, x28, x25, vlx4 + addvl x28, x28, #0x4 + KAI_ASM_INST(0x25396796) // whilelt pn14.b, x28, x25, vlx4 + add x23, x23, #0x10 + whilelt p1.b, x23, x14 + b.first label_3 + add x28, x21, x14 + ld1rw { z2.s }, p0/z, [x28] + ld1rw { z3.s }, p0/z, [x28, #0x4] + add x28, x26, x17 + KAI_ASM_INST(0xa040c794) // ld1w { z20.s - z23.s }, pn9/z, [x28] + KAI_ASM_INST(0xa041c798) // ld1w { z24.s - z27.s }, pn9/z, [x28, #0x4, mul vl] + KAI_ASM_INST(0xa042c78c) // ld1w { z12.s - z15.s }, pn9/z, [x28, #0x8, mul vl] + KAI_ASM_INST(0xc0066c04) // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] + mla z4.s, p0/m, z20.s, z2.s + mla z5.s, p0/m, z21.s, z2.s + mla z6.s, p0/m, z22.s, z2.s + mla z7.s, p0/m, z23.s, z2.s + KAI_ASM_INST(0xc132e084) // scvtf { z4.s - z7.s }, { z4.s - z7.s } + fmul z24.s, z24.s, z3.s + fmul z25.s, z25.s, z3.s + fmul z26.s, z26.s, z3.s + fmul z27.s, z27.s, z3.s + fmla z12.s, p0/m, z24.s, z4.s + fmla z13.s, p0/m, z25.s, z5.s + fmla z14.s, p0/m, z26.s, z6.s + fmla z15.s, p0/m, z27.s, z7.s + KAI_ASM_INST(0xc1bfcbcc) // fclamp { z12.s - z15.s }, z30.s, z31.s + KAI_ASM_INST(0xa036d36c) // st1w { z12.s - z15.s }, pn12, [x27, x22, lsl #2] + ldr x20, [x0, #0x48] // rhs_stride + add x26, x26, x20 + addvl x22, x22, #0x1 + KAI_ASM_INST(0x25af66d4) // whilelt pn12.s, x22, x15, vlx4 + b.lt label_2 + ldr x20, [x0, #0x18] // dst_stride_row + add x16, x16, x20 + ldr x20, [x0, #0x40] // lhs_stride + add x21, x21, x20 + ldr x20, [x0, #0x58] // lhs_end + cmp x21, x20 + b.lt label_1 + KAI_ASM_INST(0xd503467f) // smstop ldp d14, d15, [sp, 96] ldp d12, d13, [sp, 80] ldp x27, x28, [sp, 64] -- GitLab From a022cbf2a7f88a6065030cba00a26e6ab496de07 Mon Sep 17 
00:00:00 2001 From: Anitha Raj Date: Fri, 30 May 2025 15:06:10 +0100 Subject: [PATCH 14/14] Rename kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot to kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot Signed-off-by: Anitha Raj --- CMakeLists.txt | 4 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c} | 44 +++++++++---------- ..._qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h} | 22 +++++----- ...8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S} | 10 ++--- .../matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp | 6 +-- 6 files changed, 44 insertions(+), 44 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/{kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c => kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c} (91%) rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/{kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h => kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h} (96%) rename kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/{kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S => kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S} (98%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 09b6769d..9acbb442 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,8 +242,8 @@ set(KLEIDIAI_FILES_SME set(KLEIDIAI_FILES_SME2_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c - kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c + 
kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S ) set(KLEIDIAI_FILES_SME2 diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 3dae80f4..646e9a2f 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -160,7 +160,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME2_KERNELS_ASM = [ "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", - "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot", + "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", ] # buildifier: keep sorted diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c similarity index 91% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c index be115774..6f595762 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c @@ -11,7 +11,7 @@ #error "This file must be compiled for AArch64, FEAT_SVE2" #else // Architectural features check. 
-#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" +#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" #include @@ -34,7 +34,7 @@ typedef struct { float clamp_max; // 0x64 } KernelArgs; -void kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(KernelArgs* args_ptr); +void kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 1; @@ -67,7 +67,7 @@ inline static size_t kai_k_roundedup(size_t k) { inline static size_t kai_get_lhs_packed_stride(size_t k) { const size_t k_internal = kai_k_roundedup(k); KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); - const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(); size_t lhs_packed_stride = mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); // Since the LHS matrix is asymmetric with per-row quantization, we must include the // the number of bytes to hold the zero point value @@ -80,7 +80,7 @@ inline static size_t kai_get_rhs_packed_stride(size_t k) { const size_t k_internal = kai_k_roundedup(k); KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); - const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(); size_t rhs_packed_stride = nr * (k_internal * kai_num_bytes_qvalue_rhs); rhs_packed_stride += nr * kai_num_bytes_multiplier_rhs; @@ -93,55 +93,55 @@ inline static size_t kai_get_rhs_packed_stride(size_t k) { return rhs_packed_stride; } -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) { return kai_m_step; } -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { 
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) { return kai_n_step * kai_get_sme_vector_length_u8() / kai_kr; } -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) { return kai_nr * kai_get_sme_vector_length_u8() / kai_kr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) { - KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0); return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) { - KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); - const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0); + const size_t nr = 
kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(); return (n_idx / nr) * kai_get_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); - KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot()) == 0); + KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0); + KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0); return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(size_t m, size_t n) { return m * n * kai_num_bytes_dst_value; } -void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( +void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot( size_t m, // size_t n, // size_t k, // @@ -161,7 +161,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( const size_t k_internal = kai_k_roundedup(k); const size_t lhs_stride = kai_get_lhs_packed_stride(k); const size_t rhs_stride = kai_get_rhs_packed_stride(k); - const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(); const size_t rhs_row_bytes = nr * k_internal; const size_t lhs_end_ptr = ((size_t)lhs_packed) + (m * lhs_stride); @@ -183,7 +183,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( args.rhs_row_bytes = rhs_row_bytes; args.lhs_end = lhs_end_ptr; - 
kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(&args); + kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h similarity index 96% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h index c7d85e8c..55628044 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h @@ -25,34 +25,34 @@ extern "C" { /// be processed must be a multiple of m step. /// /// @return the m step value -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void); /// Gets the n step value. /// The micro-kernel can process any N values. However, the starting N index to /// be processed must be a multiple of n step. /// /// @return the n step -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void); /// Gets the mr value, which must be used to pack the LHS matrix /// /// @return the mr value -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void); /// Gets the nr value, which must be used to pack the RHS matrix. 
/// /// @return the nr value -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void); /// Gets the kr value, which must be used to pack the LHS and RHS matrices /// /// @return the kr value -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void); /// Gets the sr value, which must be used to pack the LHS and RHS matrices /// /// @return the sr value -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void); /// Gets the offset in bytes for the packed LHS matrix, /// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. @@ -63,7 +63,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot(void) /// @param[in] k Total number of columns in the LHS matrix (not packed). /// /// @return the offset in bytes to the packed LHS matrix -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot( size_t m_idx, // size_t k); // @@ -74,7 +74,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_ /// @param[in] k The common dimension between the LHS and RHS matrix (K). 
/// /// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot( size_t n_idx, // size_t k); // @@ -85,7 +85,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_ /// @param[in] dst_stride The number of bytes in in each row of the DST matrix /// /// @return the DST offset in bytes -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot( size_t m_idx, // size_t n_idx, // size_t dst_stride); // @@ -96,7 +96,7 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sd /// @param[in] n Number of columns in the destination (DST) matrix. /// /// @return the destination (DST) matrix size in bytes -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot( size_t m, // size_t n); // @@ -120,7 +120,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. /// @param[in] scalar_min Min value used to clamp the final result. /// @param[in] scalar_max Max value used to clamp the final result. 
-void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot( +void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot( size_t m, // size_t n, // size_t k, // diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S similarity index 98% rename from kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S rename to kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S index fabb060b..0f82b0a3 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S @@ -35,13 +35,13 @@ #define KAI_ASM_END #endif - KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) + KAI_ASM_CODE(matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot) KAI_ASM_ALIGN - KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot) -KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) -KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot) stp x19, x20, [sp, -112]! 
stp x21, x22, [sp, 16] stp x23, x24, [sp, 32] @@ -145,6 +145,6 @@ KAI_ASM_LABEL(label_3) // Block Loop ldp x21, x22, [sp, 16] ldp x19, x20, [sp], 112 ret - KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot) + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot) KAI_ASM_END diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index d5618a4c..d8db748a 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -16,7 +16,7 @@ #include #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_sdot.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" @@ -48,8 +48,8 @@ static const std::array