From 8079e85668f542ec026f891c722a69a9aef83f42 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 4 Jul 2025 15:31:04 +0100 Subject: [PATCH 01/13] SME Matmul Micro-kernel F32 <- (QSI8D32) LHS x (QAI4C32) RHS * Matrix multiplication (MxN) micro-kernels to compute the matrix multiplication of dynamically quantized symmetric signed 8-bit integer with per-block quantization (QSI8D32) LHS matrix and quantized asymmetric 4-bit signed integer with per-block quantization (QAI4C32) RHS matrix and the accumulation of the result into a single-precision (F32) output, optimized for SME2 technology * Matrix multiplication (1xN) micro-kernels to compute the matrix multiplication of dynamically quantized symmetric signed 8-bit integer with per-block quantization (QSI8D32) LHS matrix and quantized asymmetric 4-bit signed integer with per-block quantization (QAI4C32) RHS matrix and the accumulation of the result into a single-precision (F32) output, optimized for SME2 technology Signed-off-by: Anitha Raj --- CMakeLists.txt | 3 + ...32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 469 ++++++++++++++++++ ...32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h | 145 ++++++ ...qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c | 369 ++++++++++++++ ...qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h | 145 ++++++ ...ai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c | 183 +++++++ ...ai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h | 107 ++++ 7 files changed, 1421 insertions(+) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 1146a801..f33fda2a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,6 +161,7 @@ set(KLEIDIAI_FILES_NEON_ASM kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c @@ -319,6 +320,8 @@ set(KLEIDIAI_FILES_SME2_ASM set(KLEIDIAI_FILES_SME2 ${KLEIDIAI_FILES_SME2_ASM} + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c 
kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c new file mode 100644 index 00000000..7f161ca6 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -0,0 +1,469 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) +#error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2. +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h" + +#include <stddef.h> + +#include "kai/kai_common.h" + +// Compute args +static const size_t kai_m_step = 1; // Multiple of vector length +static const size_t kai_n_step = 4; // Multiple of vector length +// Packing args +static const size_t kai_mr = 1; // Multiple of vector length +static const size_t kai_nr = 4; // Multiple of vector length +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_sum_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +// Look-up table used for the int4->int8 conversion +static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { + const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME(bl % kai_bl == 0); + KAI_ASSUME((k % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + + size_t rhs_packed_stride = nr * (num_bytes_per_block * 
num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_m_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_n_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, size_t k, size_t bl) { + const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + + KAI_ASSUME((m_idx % m_step) == 0); + + return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride) { + const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + KAI_ASSUME((m_idx % m_step) == 0); + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + + if (m == 0) { + return; + } + + typedef struct { + size_t lhs_packed_stride; + size_t rhs_packed_stride; + size_t mr; + size_t m; + size_t n; + size_t bl; + float min; + float max; + size_t rbias; + } KernelArgs; + + const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const float* lhs_params = (const 
float*)((const int8_t*)lhs_packed + (mr * bl)); + const float* rhs_params = (const float*)((const int8_t*)rhs_packed + nr * (bl / kai_num_bytes_recip_qvalue_rhs)); + + KernelArgs ka; + ka.mr = mr; + ka.m = m; + ka.n = n; + ka.bl = bl; + ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl); + ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); + ka.min = scalar_min; + ka.max = scalar_max; + ka.rbias = ka.rhs_packed_stride - nr * kai_num_bytes_bias; + + __asm__ volatile( + // Switch to streaming mode with ZA enabling + " .inst 0xd503477f // smstart \n" + + // Constants + // - SVLs + " cntw x14 \n" + // - ptrue + " ptrue p0.b, all \n" + " .inst 0x25a07810 // ptrue pn8.s \n" + + // Predicate for loading fp32 parameters + " ldr x5, [%x[args_ptr], %[offset_mr]]\n" + " lsl x5, x5, #2 \n" + " whilelt p4.b, xzr, x5 \n" + + // Initialize ZT0 (Lookup table) + " mov x6, %[lut]\n" + " .inst 0xe11f80c0 // ldr zt0, [x6] \n" + + // Initialize the RHS packed and scale pointers + " mov x16, %[rhs_packed] \n" + " mov x17, %[rhs_params] \n" + + // Load the clamp values + " ld1rw z9.s, p0/z, [%x[args_ptr], %[offset_min]] \n" + " ld1rw z10.s, p0/z, [%x[args_ptr], %[offset_max]] \n" + " ldr x4, [%x[args_ptr], %[offset_bl]]\n" + + // Iterate over n (x8) + // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) + " mov x8, #0 \n" + " ldr x0, [%x[args_ptr], %[offset_n]] \n" + " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n" + + " b.none 9f // .LOOP_N_END%= \n" + + " 1: // .LOOP_N_START%=: \n" + + // Iterate over m (x9) + // e.g. for(m_idx = 0; m_idx < m; m_idx+=m_step) + " ldr x9, [%x[args_ptr], %[offset_m]]\n" + + // Initialize the LHS packed and scale pointers + " mov x22, %[lhs_packed] \n" + " mov x23, %[lhs_params] \n" + + // Initialize the DST pointer + " mov x24, %[dst] \n" + + " 2: // .LOOP_M_START%=: \n" + + // Address offset for the left and right quantized values + " mov x20, #0 \n" + " mov x21, #0 \n" + " mov x17, x16 \n" + " mov x3, x22 \n" + + // Number of output rows to store -> min(SVLs, loop M index) + " cmp x9, x14 \n" + " csel x15, x9, x14, lo \n" + " lsl x15, x15, #2 \n" + + // Iterate over all K values + // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) + " mov x10, %[K] \n" + + // Skip processing if K=0 + " cmp x10, #0 \n" + " b.eq 8f // .LOOP_K_END%= \n" + + " 3: // .LOOP_K_START%=: \n" + + // Zeroing of ZA accumulator + " .inst 0xc00800ff // zero {za} \n" + + // Iterate over all values in the block + // k_blk_idx = bl + // e.g. while(k_blk_idx > 0) {... 
k_blk_idx -= 4} + " mov x11, x4 \n" + + " 4: // .LOOP_BL_START%=: \n" + + // Load right matrix row + " .inst 0xa0404222 //ld1w {z2.s - z3.s}, pn8/z, [x17] \n" + " addvl x17, x17, #2 \n" + + // Load left matrix column + " ld1h {z8.h}, p0/z, [x3, x20, lsl #1] \n" + " addvl x3, x3, #1 \n" + + // Convert Int4 -> Int8 + " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n" + " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n" + + // Outer-products + " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n" + " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n" + " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n" + " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n" + + // Decrement the block loop index + " subs x11, x11, #4 \n" + + " b.gt 4b // .LOOP_BL_START%= \n" + + // === End of the block loop === + + // Store loop index + " mov w12, #0 \n" + + // Copy destination pointer for store loop + " mov x25, x24 \n" + + // Load the fp32 sum, scaling factors for the left matrix block + // and zero points, scaling factors for the right matrix block + + " ld1b {z17.b}, p4/z, [x3] \n" // lhs sum + " ld1b {z16.b}, p4/z, [x3, #1, mul vl] \n" // lhs scale + " addvl x3, x3, #2 \n" + + " .inst 0xa040c234 // ld1w { z20.s - z23.s }, pn8/z, [x17]\n" // rhs zp + " .inst 0xa041c220 // ld1w { z0.s - z3.s }, pn8/z, [x17, #4, mul vl ]\n" // rhs scale + " addvl x17, x17, #8 \n" + + // Predicate for the selection of a scaling among the vector + " pfalse p3.b \n" + + " 5: // .LOOP_ZA%=: \n" + + // Select and replicate scaling factor for the right block + " pnext p3.s, p0, p3.s \n" + " clastb z19.s, p3, z19.s, z16.s \n" + " clastb z18.s, p3, z18.s, z17.s \n" + + // Get data from za + " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n" + " add w12, w12, #4 \n" + + // Convert from int32 to fp32 + " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" + + // za now contains lhs * rhs, this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale )+ rhs_zp * + // lhs_sum rhs scaling factor * lhs scaling factor + " fmul z4.s, z0.s, z19.s \n" + " fmul z5.s, z1.s, z19.s \n" + " fmul z6.s, z2.s, z19.s \n" + " fmul z7.s, z3.s, z19.s \n" + + " cmp x10, %[K] \n" + " b.ne 6f // .ACCUMULATE%= \n" + + // rhs zero point * lhs row sum + " fmul z24.s, z20.s, z18.s \n" + " fmul z25.s, z21.s, z18.s \n" + " fmul z26.s, z22.s, z18.s \n" + " fmul z27.s, z23.s, z18.s \n" + + // Applying combined scaling factors to processed block + " fmla z24.s, p0/m, z4.s, z28.s \n" + " fmla z25.s, p0/m, z5.s, z29.s \n" + " fmla z26.s, p0/m, z6.s, z30.s \n" + " fmla z27.s, p0/m, z7.s, z31.s \n" + + "b 7f // .STORE%= \n" + + " 6: // .ACCUMULATE%=: \n" + // Load intermediate result + " .inst 0xa040c738 // ld1w {z24.s-z27.s}, pn9/z, [x25] \n" + + // rhs zero point * lhs row sum + " fmla z24.s, p0/m, z20.s, z18.s \n" + " fmla z25.s, p0/m, z21.s, z18.s \n" + " fmla z26.s, p0/m, z22.s, z18.s \n" + " fmla z27.s, p0/m, z23.s, z18.s \n" + + // Multiply the intermediate results by LHS_SCALE x RHS_SCALE + // and store in the main floating-point accumulator + " fmla z24.s, p0/m, z4.s, z28.s \n" + " fmla z25.s, p0/m, z5.s, z29.s \n" + " fmla z26.s, p0/m, z6.s, z30.s \n" + " fmla z27.s, p0/m, z7.s, z31.s \n" + + "7: // .STORE%=: \n" + // Store the results into memory + " .inst 0xa060c738 // st1w {z24.s-z27.s}, pn9, [x25] \n" + " add x25, x25, %[stride] \n" + + " cmp x12, x15 \n" + " blt 5b // .LOOP_ZA%= \n" + + // Decrement K loop index by bl + " subs x10, x10, x4 \n" + + " b.gt 3b // .LOOP_K_START%= \n" + + 
" 8: // .LOOP_K_END%=: \n" + + // === End of the K loop === + + // Load bias + " ldr x5, [%x[args_ptr], %[offset_bias]]\n" + " add x5, x5, x16\n" + " .inst 0xa040c0ac // ld1w {z12.s - z15.s}, pn8/z, [x5] \n " + + "mov x5, x24\n" + "mov x12, 0\n" + " 10: \n" // Bias loop + // Load acc + " .inst 0xa040c718 //ld1w {z24.s-z27.s}, pn9/z, [x24] \n" + // Add bias + " fadd z24.s, p0/m, z24.s, z12.s \n" + " fadd z25.s, p0/m, z25.s, z13.s \n" + " fadd z26.s, p0/m, z26.s, z14.s \n" + " fadd z27.s, p0/m, z27.s, z15.s \n" + + // Clamp + " .inst 0xc1aac938 //fclamp { z24.s - z27.s }, z9.s, z10.s \n" + + ".inst 0xa060c718 //st1w {z24.s-z27.s}, pn9, [x24] \n" + + " add x24, x24, %[stride] \n" + " add x12, x12, #4\n" + " cmp x12, x15 \n" + " blt 10b // Bias loop \n" + + "mov x24, x5\n" + + " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n" + + // Increment pointer to the quantized values of the lhs matrix + " add x22, x22, x5\n" + + // Increment pointer to the scaling factors of the lhs matrix + " add x23, x23, x5 \n" + + // Update destination pointer + " mov x24, x25 \n" + + // Decrement M loop index + " decw x9, all \n" + + " cmp x9, #0 \n" + " b.gt 2b // .LOOP_M_START%= \n" + + // === End of M loop === + + // Increment output pointer + " incb %[dst], all, mul #4 \n" + + " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n" + + " add x16, x16, x5 \n" + " add x17, x17, x5 \n" + + // Increment N loop index + " incb x8, all \n" + + " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n" + + " b.first 1b // .LOOP_N_START%= \n" + + " 9: // .LOOP_N_END%=: \n" + + // === End of N loop === + + // Exit streaming mode + " .inst 0xd503467f // smstop \n" + : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_params] "+r"(rhs_params) + : [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_params] "r"(lhs_params), [stride] "r"(dst_stride_row), + [lut] "r"(lut), [args_ptr] "r"(&ka), [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)), + [offset_m] "I"(offsetof(KernelArgs, m)), [offset_n] "I"(offsetof(KernelArgs, n)), + [offset_min] "I"(offsetof(KernelArgs, min)), [offset_max] "I"(offsetof(KernelArgs, max)), + [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr)), + [offset_bias] "I"(offsetof(KernelArgs, rbias)), [offset_bl] "I"(offsetof(KernelArgs, bl)) + : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1", + "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", + "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x3", "x4", + "x5", "x6", "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", + "x25", "memory", "cc"); +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h new file mode 100644 index 00000000..3a9014cc --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h @@ -0,0 +1,145 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon to dynamically quantize and pack the LHS matrix in a single +/// step. +/// -# @ref kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (32) quantization (qsi8d32) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Column index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] bl Block length. It must be a multiple of 32. 
+/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (32) quantization (qsi8d32) and packed +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (32) quantization (qai4c32) and packed. +/// Output tile: (rows x cols) = 1 VL x 4 VL (Vector Length) +/// +/// Instruction used: SME2 (MOPA) +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. 
+void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c new file mode 100644 index 00000000..5e0ec198 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c @@ -0,0 +1,369 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) +#error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2. +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h" + +#include <stddef.h> + +#include "kai/kai_common.h" + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; // Multiple of vector length +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; // Multiple of vector length +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_sum_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +// Look-up table used for int4->int8 convert +static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { + const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME(bl % kai_bl == 0); + KAI_ASSUME((k % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_n_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, size_t k, size_t bl) { + const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + KAI_ASSUME((m_idx % m_step) == 0); + + return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t nr = 
kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, size_t n_idx, size_t dst_stride) { + const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + KAI_ASSUME((m_idx % m_step) == 0); + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + KAI_ASSUME(m == 1); + + KAI_UNUSED(dst_stride_row); + + if (m == 0) { + return; + } + + const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); + + const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + const float* lhs_params = (const float*)((const int8_t*)lhs_packed + (mr * bl)); + const size_t rhs_params_offset = nr * (bl / kai_num_bytes_recip_qvalue_rhs); + const size_t rhs_bias_offset = rhs_packed_stride - nr * kai_num_bytes_bias; + __asm__ volatile( + // Switch to streaming mode with ZA enabling + " .inst 0xd503477f // smstart \n" + + " ptrue p2.b, all \n" + " .inst 0x25607810 // ptrue pn8.h \n" + + " fmov z28.s, #0.0 \n" + + // Initialize ZT0 (Lookup table) + " mov x9, %[lut] \n" + " .inst 0xe11f8120 // ldr zt0, [x9] \n" + + // Initialize the RHS packed and params pointers + " mov x10, %[rhs_packed] \n" + + // Initialize the DST pointer + " mov x5, %[dst] \n" + + // Load the clamp values + " dup z29.s, %w[scalar_min] \n" + " dup z30.s, %w[scalar_max] \n" + + // Iterate over n (x0) + // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) + " mov x4, #0\n" + " mov x17, %[n] \n" + " .inst 0x25b16491 // whilelt pn9.s, x4, x17, VLx4 \n" + + " b.none 5f // .LOOP_N_END%= \n" + + " 1: // .LOOP_N_START%=: \n" + + // Initialize the LHS packed and params pointers + " mov x2, %[lhs_packed] \n" + " mov x3, %[lhs_params] \n" + " mov x0, x10 \n" + + // Initialize the 4xVL-32bit accumulators to zero + " dup z24.s, #0 \n" + " dup z25.s, #0 \n" + " dup z26.s, #0 \n" + " dup z27.s, #0 \n" + + // Initialize the vector selector for ZA array + " mov w8, #0 \n" + + // Iterate over all K values + // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) + " mov x6, #0 \n" + " whilelt p1.s, x6, %[k] \n" + " b.none 4f // .LOOP_K_END%= \n" + + " 2: // .LOOP_K_START%=: \n" + // Zeroing of inner accumulation array + " .inst 0xc00800ff // zero {za} \n" + + // Iterate over all values in the block + // k_blk_idx = bl + // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 16} + + "mov x13, %[bl] \n" + + "3: // .LOOP_BL_START%=: \n" + // Load the LHS (int8) quantized values + // Load contiguous 16 bytes and replicate. 
+ // For GeMV, we do not interleave the LHS M rows. + " ld1rqb { z0.b }, p2/z , [x2] \n" + " add x2, x2, #16 \n" + + // -- First half + // Load the RHS (int4) quantized values + " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0] \n" + + // Increment the RHS pointer + " addvl x0, x0, #4 \n" + + // Convert Int4 -> Int8 + " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" + " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" + " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" + " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" + + // SDOT indexed + " .inst 0xc15090a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[0] \n" + " .inst 0xc1509520 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[1] \n" + + // -- Second half + + // Load the RHS (int4) quantized values + " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0]\n" + + // Increment the RHS pointer + " addvl x0, x0, #4 \n" + + // Convert Int4 -> Int8 + " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" + " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" + " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" + " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" + + // SDOT indexed + " .inst 0xc15098a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[2] \n" + " .inst 0xc1509d20 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[3] \n" + + // Decrement the block loop index + "subs x13, x13, #16 \n" + + "b.gt 3b // .LOOP_BL_START%= \n" + + // === End of the block loop === + + // Load Z registers with intermediate values from ZA array + " .inst 0xc0060c10 // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] \n" + + // Convert from int32 to float32 + " .inst 0xc132e210 // scvtf{z16.s - z19.s}, {z16.s - z19.s} \n" + + // Load 1 fp32 LHS sum and scale value and replicate for VL + " ld1rw z1.s, p2/z, [x2] \n" // sum + " ld1rw z2.s, p2/z, [x2, #4] \n" // scale + + // Increment the LHS param pointer by 8 (2 x sizeof(fp32)) + " add x2, x2, #8 \n" + + // Load 4xVL-32bit (fp32) RHS zp and scales. 
+ // If VL=512bit, we load 64 fp32 values, which is equal to the number of output columns (n_step) processed + " .inst 0xa040c004 // ld1w { z4.s - z7.s }, pn8/z, [x0]\n" // zp + " .inst 0xa041c008 // ld1w { z8.s - z11.s }, pn8/z, [x0, #0x4, mul vl ]\n" // scale + + // Increment the RHS pointer + " addvl x0, x0, #8 \n" + + // za now contains lhs * rhs, this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale )+ rhs_zp * + // lhs_sum rhs zero point * lhs row sum + " fmla z24.s, p2/m, z4.s, z1.s \n" + " fmla z25.s, p2/m, z5.s, z1.s \n" + " fmla z26.s, p2/m, z6.s, z1.s \n" + " fmla z27.s, p2/m, z7.s, z1.s \n" + + // lhs scaling factor * rhs scaling factor + " fmul z8.s, z8.s, z2.s \n" + " fmul z9.s, z9.s, z2.s \n" + " fmul z10.s, z10.s, z2.s \n" + " fmul z11.s, z11.s, z2.s \n" + + // Multiply the intermediate results by LHS_SCALE x RHS_SCALE + // and store in the main floating-point accumulator + " fmla z24.s, p2/m, z16.s, z8.s \n" + " fmla z25.s, p2/m, z17.s, z9.s \n" + " fmla z26.s, p2/m, z18.s, z10.s \n" + " fmla z27.s, p2/m, z19.s, z11.s \n" + + // Increment the number of K values processed and + // go to the next block + " add x6, x6, %[bl] \n" + " whilelt p1.s, x6, %[k] \n" + " b.first 2b // .LOOP_K_START%= \n" + " 4: //.LOOP_K_END%=: \n" + + // Load bias + " .inst 0xa040c014 // ld1w { z20.s - z23.s }, pn8/z, [x0]\n " + + // Add bias + " fadd z24.s, p2/m, z24.s, z20.s \n" + " fadd z25.s, p2/m, z25.s, z21.s \n" + " fadd z26.s, p2/m, z26.s, z22.s \n" + " fadd z27.s, p2/m, z27.s, z23.s \n" + + // Clamp + " .inst 0xc1becbb8 // fclamp { z24.s - z27.s }, z29.s, z30.s \n" + + // Store the results into memory + " .inst 0xa060c4b8 // st1w { z24.s-z27.s }, pn9, [x5] \n" + " incb x4, all \n" + " addvl x5, x5, #4 \n" + + // Update the rhs pointers + " add x10, x10, %[rhs_packed_stride] \n" + + " .inst 0x25b16491 // whilelt pn9.s, x4, %[n], VLx4 \n" + + " b.first 1b // .LOOP_N_START%= \n" + + " 5: // .LOOP_N_END%=: \n" + + // Exit streaming mode + " .inst 0xd503467f //smstop \n" + : + : [lut] "r"(lut), [dst] "r"(dst), [rhs_packed] "r"(rhs_packed), [rhs_params] "r"(rhs_params_offset), + [lhs_packed] "r"(lhs_packed), [lhs_params] "r"(lhs_params), [rhs_packed_stride] "r"(rhs_packed_stride), + [rhs_bias] "r"(rhs_bias_offset), [scalar_min] "r"(scalar_min), [scalar_max] "r"(scalar_max), + [n] "r"((int64_t)n), [k] "r"(k), [bl] "r"(bl) + : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", + "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", + "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x1", + "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x13", "x17", "memory", "cc"); +} + +#endif // Architectural features check. 
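Reviewer note (not part of the patch): for anyone checking the scale/zero-point algebra shared by the MOPA and SDOT bodies, here is a scalar reference of the per-output-element computation as I read it. This is a sketch under assumptions: q_lhs/q_rhs are the already-decoded int8 values (the int4 RHS is assumed expanded through the ZT0 look-up table), and lhs_scaled_sum is the per-block LHS row sum already multiplied by the LHS scale, which is what the fmla sequences imply the LHS packing stores.

```c
#include <stddef.h>
#include <stdint.h>

// Scalar reference of one clamped output element:
//   acc = sum_over_blocks( iacc * (lhs_scale * rhs_scale) + rhs_zp * lhs_scaled_sum ) + bias
// mirroring the "(lhs * rhs) * (lhs_scale * rhs_scale) + rhs_zp * lhs_sum"
// comments in the assembly above. Packed-buffer layouts are abstracted away.
static float ref_output_element(
    size_t k, size_t bl,
    const int8_t* q_lhs, const float* lhs_scale, const float* lhs_scaled_sum,
    const int8_t* q_rhs, const float* rhs_scale, const float* rhs_zp,
    float bias, float min, float max) {
    float acc = 0.0f;
    for (size_t b = 0; b < k / bl; ++b) {
        // Integer dot product over one block (what SMOPA/SDOT accumulate in ZA).
        int32_t iacc = 0;
        for (size_t i = 0; i < bl; ++i) {
            iacc += (int32_t)q_lhs[b * bl + i] * (int32_t)q_rhs[b * bl + i];
        }
        acc += (float)iacc * (lhs_scale[b] * rhs_scale[b]) + rhs_zp[b] * lhs_scaled_sum[b];
    }
    acc += bias;
    return acc < min ? min : (acc > max ? max : acc);
}
```

If the LHS packing instead stored the raw integer sum, the rhs_zp term would need an extra lhs_scale factor; the kernels apply no such factor, which is why the scaled-sum assumption is stated above.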
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h new file mode 100644 index 00000000..55cccc39 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h @@ -0,0 +1,145 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon to dynamically quantize and pack the LHS matrix in a single +/// step. +/// -# @ref kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (32) quantization (qsi8d32) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Column index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] bl Block length. It must be a multiple of 32. 
+/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (32) quantization (qsi8d32) and packed +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (32) quantization (qai4c32) and packed. +/// Output tile: (rows x cols) = 1 x 4 VL (Vector Length) +/// +/// Instruction used: SME2 (sdot) +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c new file mode 100644 index 00000000..789ba2c2 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c @@ -0,0 +1,183 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) && !defined(_M_ARM64) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
+#include "kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_offset_rhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_bl_multiple_of = 32; + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_num_bytes_per_block(size_t bl) { + return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kr) == 0); + KAI_ASSUME((bl % kai_bl_multiple_of) == 0); + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl); + return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % nr) == 0); + KAI_UNUSED(kr); + return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t n, size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_UNUSED(kr); + const size_t num_rows = kai_roundup(n, nr) / nr; + return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl); +} + +void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, + const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qai4c32p_params* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % 32) == 0); + KAI_ASSUME(extra_bytes == 0); + + // KAI_ASSUME(sr == 2); + KAI_ASSUME(kr >= 1 && kr <= 16); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(zero != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(params != NULL); + KAI_ASSUME(params->rhs_zero_point == 8); + KAI_ASSUME(params->lhs_zero_point == 1); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + + const size_t num_blocks_per_row = k / bl; + const size_t rhs_stride = k; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl); + + const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr; + const size_t dst_block_data_size = (bl / 2) * nr; + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size; + const size_t k_block_length_in_bytes = kr / 2; + const size_t k_interleaved_v = 1U; + + const size_t rhs_zero_point = params->rhs_zero_point; + + for (size_t dst_row_idx = 0; dst_row_idx < 
dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + float* dst_row_bias = (float*)(dst_row + dst_bias_offset); + + for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) { + uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size; + float* block_dst_zp = (float*)(block_dst_row + dst_block_data_size); + float* block_dst_scale = block_dst_zp + nr; + + for (size_t block_byte_idx = 0; block_byte_idx < dst_block_data_size; ++block_byte_idx) { + const size_t dst_byte_idx = block_byte_idx; + const size_t k_block_idx = dst_byte_idx / k_block_length_in_bytes; + const size_t k_block_byte_idx = dst_byte_idx % k_block_length_in_bytes; + const size_t super_k_block_idx = k_block_idx / nr; + const size_t nr_idx = k_block_idx % nr; + + const size_t k_adjustment = + ((k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes) / k_interleaved_v) * + k_interleaved_v; + const size_t k0_idx = k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bounds reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2; + const size_t src_addr_byte1 = (k1_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *block_dst_row = dst_qs0; + block_dst_row += sizeof(uint8_t); + } + + // Adjust the zero points and scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bounds reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + + const float* block_zero = (const float*)zero + num_blocks_per_row * src_row_idx; + const float* block_scale = (const float*)scale + num_blocks_per_row * src_row_idx; + + *block_dst_zp = block_zero[block_idx]; + *block_dst_scale = block_scale[block_idx]; + + block_dst_zp++; + block_dst_scale++; + } + } + // Set the bias + if (bias == NULL) { + memset(dst_row_bias, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bounds reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + + dst_row_bias[i] = *((const float*)bias + src_row_idx); + } + } + } +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h new file mode 100644 index 00000000..37d5bb1f --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h @@ -0,0 +1,107 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include <stdint.h> + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef kai_rhs_pack_nxk_qai4c32p_params +#define kai_rhs_pack_nxk_qai4c32p_params 
kai_rhs_pack_qs4cxs1s0_param +#endif + +/// Gets the offset in bytes for the RHS matrix (not packed), which holds +/// the int4 values in an N x K matrix, where N is the number of rows and K is the number of columns. +/// Two int4 K values are stored in one byte. These values are stored in blocks. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t n_idx, // + size_t rhs_stride); // + +/// Gets the offset in bytes for the packed RHS matrix. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K) +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single innermost loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl // +); + +/// Gets the size in bytes for the quantized and packed RHS matrix. +/// +/// @param[in] n The number of rows in the RHS matrix (not packed) +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single innermost loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple +/// of 32. +/// +/// @return the packed RHS matrix size in bytes +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl // +); + +/// Run the micro-kernel to pack the RHS matrix. +/// +/// @note The int4 values are stored in an N x K matrix, where N is the number of rows and K is the number of columns. +/// Two int4 values are stored in one byte. +/// +/// @param[in] num_groups The number of groups. It must be 1. +/// @param[in] n The number of columns of the output matrix (N). +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] nr The number of N rows to interleave on the same output row. +/// @param[in] kr The number of K values loaded in the single innermost loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be a multiple of sr. +/// @param[in] bl The block length, which defines the number of +/// K values stored in a single block. It must be a multiple of 32. +/// @param[in] rhs The RHS matrix containing the 4-bit values. +/// Size in bytes is expected to be greater than or equal to (n * k) / 2. +/// @param[in] zero The zero points. One fp32 value per block. +/// @param[in] bias The biases. One fp32 value per output channel (N). It can be NULL. +/// @param[in] scale The scales. One fp32 value per block. +/// @param[out] rhs_packed The packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. +/// @param[in] params Parameters for the micro-kernel. 
+void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + const void* zero, // + const void* bias, // + const void* scale, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qai4c32p_params* params); +#ifdef __cplusplus +} +#endif -- GitLab From e4175b5db54612fead0f54a42f84ed243a3274ef Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 7 Jul 2025 10:37:35 +0100 Subject: [PATCH 02/13] Fix for build failures Signed-off-by: Anitha Raj --- ...mul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 1 + .../kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c | 1 + 2 files changed, 2 insertions(+) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c index 7f161ca6..faeb6d9b 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -10,6 +10,7 @@ #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include <stddef.h> +#include <stdint.h> #include "kai/kai_common.h" diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c index 789ba2c2..0fb25bb4 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c @@ -75,6 +75,7 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( KAI_ASSUME((k % bl) == 0); KAI_ASSUME((bl % 32) == 0); KAI_ASSUME(extra_bytes == 0); + KAI_UNUSED(sr); // KAI_ASSUME(sr == 2); -- GitLab From 5f776d11d7a97162490fecaac7c3cf48a781612e Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 25 Jul 2025 16:55:31 +0100 Subject: [PATCH 03/13] Fix Clang tidy warnings and optimize matmul kernels Signed-off-by: Anitha Raj --- ...32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 10 +++++----- ...amp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c | 11 ++++++----- ...ck_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c | 5 +++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c index faeb6d9b..87da0fe1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -20,8 +20,8 @@ static const size_t kai_n_step = 4; // Multiple of vector length // Packing args static const size_t kai_mr = 1; // Multiple of vector length static const size_t kai_nr = 4; // Multiple of vector length -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; +static const size_t kai_kr = 8; +static const size_t 
kai_sr = 2; // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; @@ -324,9 +324,6 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n" " add w12, w12, #4 \n" - // Convert from int32 to fp32 - " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" - // za now contains lhs * rhs, this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale )+ rhs_zp * // lhs_sum rhs scaling factor * lhs scaling factor " fmul z4.s, z0.s, z19.s \n" @@ -334,6 +331,9 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " fmul z6.s, z2.s, z19.s \n" " fmul z7.s, z3.s, z19.s \n" + // Convert from int32 to fp32 + " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" + " cmp x10, %[K] \n" " b.ne 6f // .ACCUMULATE%= \n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c index 5e0ec198..782fe628 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c @@ -10,6 +10,7 @@ #include "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h" #include +#include #include "kai/kai_common.h" @@ -19,8 +20,8 @@ static const size_t kai_n_step = 4; // Multiple of vector length // Packing args static const size_t kai_mr = 1; static const size_t kai_nr = 4; // Multiple of vector length -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; // LHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric)) static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; @@ -282,9 +283,6 @@ void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( // Load Z registers with intermediate values from ZA array " .inst 0xc0060c10 // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] \n" - // Convert from int32 to float32 - " .inst 0xc132e210 // scvtf{z16.s - z19.s}, {z16.s - z19.s} \n" - // Load 1 fp32 LHS sum and scale value and replicate for VL " ld1rw z1.s, p2/z, [x2] \n" // sum " ld1rw z2.s, p2/z, [x2, #4] \n" // scale @@ -300,6 +298,9 @@ void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( // Increment the RHS pointer " addvl x0, x0, #8 \n" + // Convert from int32 to float32 + " .inst 0xc132e210 // scvtf{z16.s - z19.s}, {z16.s - z19.s} \n" + // za now contains lhs * rhs, this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale )+ rhs_zp * // lhs_sum rhs zero point * lhs row sum " fmla z24.s, p2/m, z4.s, z1.s \n" diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c index 0fb25bb4..be938d7d 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c @@ -77,7 +77,7 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( KAI_ASSUME(extra_bytes == 0); KAI_UNUSED(sr); - // KAI_ASSUME(sr == 2); + KAI_ASSUME(sr == 2); KAI_ASSUME(kr >= 1 && kr <= 16); KAI_ASSUME(rhs != NULL); KAI_ASSUME(zero != NULL); @@ -89,6 +89,7 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( // Note: The input matrix (rhs) is expected with: // "k" columns and "n" rows (NxK) + const size_t block_length = kr / sr; const size_t num_blocks_per_row = k / bl; const size_t rhs_stride = k; const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl); @@ -97,7 +98,7 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( const size_t dst_block_data_size = (bl / 2) * nr; const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size; - const size_t k_block_length_in_bytes = kr / 2; + const size_t k_block_length_in_bytes = block_length * sizeof(uint8_t) / 2; const size_t k_interleaved_v = 1U; const size_t rhs_zero_point = params->rhs_zero_point; -- GitLab From 97a3595e0e6fb6de39741831b4bd509c6e631966 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 28 Jul 2025 11:44:39 +0100 Subject: [PATCH 04/13] Add RHS packing kernel with s0s1 ordered inputs * Add new RHS packing kernel to pack int4 inputs with s0s1 order: kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon * Update Changelog Signed-off-by: Anitha Raj --- CHANGELOG.md | 4 + ...ai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c | 185 ++++++++++++++++++ ...ai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.h | 107 ++++++++++ ...ai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c | 3 +- 4 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 99cac4f3..cca98fa8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ KleidiAI follows the 
[Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- New SME micro-kernels: + - Matrix multiplication (MxN) micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_SME2. + - Matrix multiplication (1xN) micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_SME2. + ## v1.12.0 - New Advanced SIMD micro-kernels: diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c new file mode 100644 index 00000000..b85ed97f --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c @@ -0,0 +1,185 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) && !defined(_M_ARM64) +#error This file must be compiled for AArch64. +#else // Architectural features check. +#include +#include + +#include "kai/kai_common.h" +#include "kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h" + +static const size_t kai_num_bytes_offset_rhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_bl_multiple_of = 32; + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_num_bytes_per_block(size_t bl) { + return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kr) == 0); + KAI_ASSUME((bl % kai_bl_multiple_of) == 0); + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl); + return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % nr) == 0); + KAI_UNUSED(kr); + return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t n, size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_UNUSED(kr); + const size_t num_rows = kai_roundup(n, nr) / nr; + return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl); +} + +void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, + const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qai4c32p_params* params) { + KAI_ASSUME(num_groups == 1); + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl
% 32) == 0); + KAI_ASSUME(extra_bytes == 0); + KAI_UNUSED(sr); + + KAI_ASSUME(sr == 2); + KAI_ASSUME(kr >= 1 && kr <= 16); + KAI_ASSUME(rhs != NULL); + KAI_ASSUME(zero != NULL); + KAI_ASSUME(rhs_packed != NULL); + KAI_ASSUME(params != NULL); + KAI_ASSUME(params->rhs_zero_point == 8); + KAI_ASSUME(params->lhs_zero_point == 1); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + + const size_t block_length = kr / sr; + const size_t num_blocks_per_row = k / bl; + const size_t rhs_stride = k; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl); + + const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr; + const size_t dst_block_data_size = (bl / 2) * nr; + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size; + const size_t k_block_length_in_bytes = block_length * sizeof(uint8_t) / 2; + const size_t k_interleaved_v = 1U; + + const size_t rhs_zero_point = params->rhs_zero_point; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + float* dst_row_bias = (float*)(dst_row + dst_bias_offset); + + for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) { + uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size; + float* block_dst_zp = (float*)(block_dst_row + dst_block_data_size); + float* block_dst_scale = block_dst_zp + nr; + + for (size_t block_byte_idx = 0; block_byte_idx < dst_block_data_size; ++block_byte_idx) { + const size_t dst_byte_idx = block_byte_idx; + const size_t k_block_idx = dst_byte_idx / k_block_length_in_bytes; + const size_t k_block_byte_idx = dst_byte_idx % k_block_length_in_bytes; + const size_t super_k_block_idx = k_block_idx / nr; + const size_t nr_idx = k_block_idx % nr; + + const size_t k_adjustment = + ((k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes) / k_interleaved_v) * + k_interleaved_v; + const size_t k0_idx = k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2; + const size_t src_addr_byte1 = (k1_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + const size_t shift_right_x0 = (k0_idx % 2 == 0) ? 4 : 0; + const size_t shift_right_x1 = (k1_idx % 2 == 0) ? 
4 : 0; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + const int8_t dst_qs0 = src_x0_lo | + (src_x0_hi << 4); // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + + *block_dst_row = dst_qs0; + block_dst_row += sizeof(uint8_t); + } + + // Adjust the zero points and scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + + const float* block_zero = (const float*)zero + num_blocks_per_row * src_row_idx; + const float* block_scale = (const float*)scale + num_blocks_per_row * src_row_idx; + + *block_dst_zp = block_zero[block_idx]; + *block_dst_scale = block_scale[block_idx]; + + block_dst_zp++; + block_dst_scale++; + } + } + // Set the bias + if (bias == NULL) { + memset(dst_row_bias, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + + dst_row_bias[i] = *((const float*)bias + src_row_idx); + } + } + } +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.h new file mode 100644 index 00000000..5d5f5338 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.h @@ -0,0 +1,107 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef kai_rhs_pack_nxk_qai4c32p_params +#define kai_rhs_pack_nxk_qai4c32p_params kai_rhs_pack_qs4cxs1s0_param +#endif + +/// Gets the offset in bytes for the RHS matrix (not packed), which holds +/// the int4 values in an N x K matrix, where N is the number of rows and K is the number of columns. +/// Two int4 K values are stored in one byte. These values are stored in blocks. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( + size_t n_idx, // + size_t rhs_stride); // + +/// Gets the offset in bytes for the packed RHS matrix. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K) +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single innermost loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl // +); + +/// Gets the size in bytes for the quantized and packed RHS matrix. +/// +/// @param[in] n The number of rows in the RHS matrix (not packed) +/// @param[in] k The number of columns in the RHS matrix (not packed).
+/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single innermost loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple +/// of 32. +/// +/// @return the packed RHS matrix size in bytes +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl // +); + +/// Run the micro-kernel to pack the RHS matrix. +/// +/// @note The int4 values are stored in an N x K matrix, where N is the number of rows and K is the number of columns. +/// Two int4 values are stored in one byte. +/// +/// @param[in] num_groups The number of groups. It must be 1. +/// @param[in] n The number of columns of the output matrix (N). +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] nr The number of N rows to interleave on the same output row. +/// @param[in] kr The number of K values loaded in the single innermost loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be a multiple of sr. +/// @param[in] bl The block length, which defines the number of +/// K values stored in a single block. It must be a multiple of 32. +/// @param[in] rhs The RHS matrix containing the 4-bit values. +/// Size in bytes is expected to be greater than or equal to n * k * sizeof(uint8_t) / 2. +/// @param[in] zero The zero point. +/// @param[in] bias The biases. +/// @param[in] scale The scale for each output channel. +/// @param[out] rhs_packed The packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. +/// @param[in] params Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + const void* zero, // + const void* bias, // + const void* scale, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qai4c32p_params* params); +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c index be938d7d..0fe89fc3 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c @@ -148,7 +148,8 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; - const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + const int8_t dst_qs0 = src_x0_lo | + (src_x0_hi << 4); // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) *block_dst_row = dst_qs0; block_dst_row += sizeof(uint8_t); -- GitLab From 34c90ac821b1ff865aa3815c213ce2e162544c94 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 30 Jul 2025 10:12:05 +0100 Subject: [PATCH 05/13] Fix typo in rhs packing kernel names Signed-off-by: Anitha Raj --- ...ck_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c index b85ed97f..cc479bb0 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c @@ -7,11 +7,12 @@ #if !defined(__aarch64__) && !defined(_M_ARM64) #error This file must be compiled for AArch64. #else // Architectural features check. 
+#include "kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.h" + #include #include #include "kai/kai_common.h" -#include "kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h" static const size_t kai_num_bytes_offset_rhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); @@ -40,11 +41,11 @@ inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, s return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); } -size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) { +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) { return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { KAI_ASSUME((k % 2) == 0); KAI_ASSUME((k % kr) == 0); @@ -54,7 +55,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl); } -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( size_t n, size_t k, size_t nr, size_t kr, size_t bl) { KAI_ASSUME((k % 2) == 0); KAI_ASSUME((k % kr) == 0); @@ -64,7 +65,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32 return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl); } -void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( +void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qai4c32p_params* params) { -- GitLab From 1da432fda9671c0185e8ebacad199a63c61b56d5 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 30 Jul 2025 12:18:34 +0100 Subject: [PATCH 06/13] Extract inline asm to .S files * To support MSVC compiler move the inline assembly code to pure asm.S files: * kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa * kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot Signed-off-by: Anitha Raj --- CMakeLists.txt | 6 +- ...32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 342 ++---------------- ...vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S | 187 ++++++++++ ...qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c | 233 ++---------- ...d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot_asm.S | 146 ++++++++ 5 files changed, 399 insertions(+), 515 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot_asm.S diff --git a/CMakeLists.txt b/CMakeLists.txt index f33fda2a..e72fa56a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -312,6 +312,10 @@ set(KLEIDIAI_FILES_SME2_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c 
kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_asm.S kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c @@ -320,8 +324,6 @@ set(KLEIDIAI_FILES_SME2_ASM set(KLEIDIAI_FILES_SME2 ${KLEIDIAI_FILES_SME2_ASM} - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c - kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c index 87da0fe1..8f7d2ad6 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -13,7 +13,24 @@ #include #include "kai/kai_common.h" - +typedef struct { + float* dst; // 0 + const void* lhs_packed; // 0x8 + const void* rhs_packed; // 0x10 + size_t dst_stride_row; // 0x18 + size_t lhs_packed_stride; // 0x20 + size_t rhs_packed_stride; // 0x28 + size_t bias; // 0x30 + size_t m; // 0x38 + size_t n; // 0x40 + size_t k; // 0x48 + size_t bl; // 0x50 + const int32_t* lut; // 0x58 + float min; // 0x60 + float max; // 0x64 +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 1; // Multiple of vector length static const size_t kai_n_step = 4; // Multiple of vector length @@ -157,314 +174,25 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( return; } - typedef struct { - size_t lhs_packed_stride; - size_t rhs_packed_stride; - size_t mr; - size_t m; - size_t n; - size_t bl; - float min; - float max; - size_t rbias; - } KernelArgs; - - const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); - const float* lhs_params = (const float*)((const int8_t*)lhs_packed + (mr * bl)); - const float* rhs_params = (const 
float*)((const int8_t*)rhs_packed + nr * (bl / kai_num_bytes_recip_qvalue_rhs)); - - KernelArgs ka; - ka.mr = mr; - ka.m = m; - ka.n = n; - ka.bl = bl; - ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl); - ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); - ka.min = scalar_min; - ka.max = scalar_max; - ka.rbias = ka.rhs_packed_stride - nr * kai_num_bytes_bias; - - __asm__ volatile( - // Switch to streaming mode with ZA enabling - " .inst 0xd503477f // smstart \n" - - // Constants - // - SVLs - " cntw x14 \n" - // - ptrue - " ptrue p0.b, all \n" - " .inst 0x25a07810 // ptrue pn8.s \n" - - // Predicate for loading fp32 parameters - " ldr x5, [%x[args_ptr], %[offset_mr]]\n" - " lsl x5, x5, #2 \n" - " whilelt p4.b, xzr, x5 \n" - - // Initialize ZT0 (Lookup table) - " mov x6, %[lut]\n" - " .inst 0xe11f80c0 // ldr zt0, [x6] \n" - - // Initialize the RHS packes and scale pointers - " mov x16, %[rhs_packed] \n" - " mov x17, %[rhs_params] \n" - - // Load the clamp values - " ld1rw z9.s, p0/z, [%x[args_ptr], %[offset_min]] \n" - " ld1rw z10.s, p0/z, [%x[args_ptr], %[offset_max]] \n" - " ldr x4, [%x[args_ptr], %[offset_bl]]\n" - - // Iterate over n (x8) - // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) - " mov x8, #0 \n" - " ldr x0, [%x[args_ptr], %[offset_n]] \n" - " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n" - - " b.none 9f // .LOOP_N_END%= \n" - - " 1: // .LOOP_N_START%=: \n" - - // Iterate over m (x9) - // e.g. for(m_idx = 0; m_idx < m; m_idx+=m_step) - " ldr x9, [%x[args_ptr], %[offset_m]]\n" - - // Initialize the LHS packed and scale pointers - " mov x22, %[lhs_packed] \n" - " mov x23, %[lhs_params] \n" - - // Initialize the DST pointer - " mov x24, %[dst] \n" - - " 2: // .LOOP_M_START%=: \n" - - // Address offset for the left and right quantized values - " mov x20, #0 \n" - " mov x21, #0 \n" - " mov x17, x16 \n" - " mov x3, x22 \n" - - // Number of output rows to store -> min(SVLh, loop M index) - " cmp x9, x14 \n" - " csel x15, x9, x14, lo \n" - " lsl x15, x15, #2 \n" - - // Iterate over all K values - // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) - " mov x10, %[K] \n" - - // Skip processing if K=0 - " cmp x10, #0 \n" - " b.eq 8f // .LOOP_K_END%= \n" - - " 3: // .LOOP_K_START%=: \n" - - // Zeroing of ZA accumulator - " .inst 0xc00800ff // zero {za} \n" - - // Iterate over all values in the block - // k_blk_idx = bl - // e.g. while(k_blk_idx > 0) {... 
k_blk_idx -= 4} - " mov x11, x4 \n" - - " 4: // .LOOP_BL_START%=: \n" - - // Load right matrix row - " .inst 0xa0404222 //ld1w {z2.s - z3.s}, pn8/z, [x17] \n" - " addvl x17, x17, #2 \n" - - // Load left matrix column - " ld1h {z8.h}, p0/z, [x3, x20, lsl #1] \n" - " addvl x3, x3, #1 \n" - - // Convert Int4 -> Int8 - " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n" - " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n" - - // Outer-products - " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n" - " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n" - " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n" - " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n" - - // Decrement the block loop index - " subs x11, x11, #4 \n" - - " b.gt 4b // .LOOP_BL_START%= \n" - - // === End of the block loop === - - // Store loop index - " mov w12, #0 \n" - - // Copy destination pointer for store loop - " mov x25, x24 \n" - - // Load the fp32 sum, scaling factors for the left matrix block - // and zero points, scaling factors for the right matrix block - - " ld1b {z17.b}, p4/z, [x3] \n" // lhs sum - " ld1b {z16.b}, p4/z, [x3, #1, mul vl] \n" // lhs scale - " addvl x3, x3, #2 \n" - - " .inst 0xa040c234 // ld1w { z20.s - z23.s }, pn8/z, [x17]\n" // rhs zp - " .inst 0xa041c220 // ld1w { z0.s - z3.s }, pn8/z, [x17, #4, mul vl ]\n" // rhs scale - " addvl x17, x17, #8 \n" - - // Predicate for the selection of a scaling among the vector - " pfalse p3.b \n" - - " 5: // .LOOP_ZA%=: \n" - - // Select and replicate scaling factor for the right block - " pnext p3.s, p0, p3.s \n" - " clastb z19.s, p3, z19.s, z16.s \n" - " clastb z18.s, p3, z18.s, z17.s \n" - - // Get data from za - " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n" - " add w12, w12, #4 \n" - - // za now contains lhs * rhs, this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale )+ rhs_zp * - // lhs_sum rhs scaling factor * lhs scaling factor - " fmul z4.s, z0.s, z19.s \n" - " fmul z5.s, z1.s, z19.s \n" - " fmul z6.s, z2.s, z19.s \n" - " fmul z7.s, z3.s, z19.s \n" - - // Convert from int32 to fp32 - " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" - - " cmp x10, %[K] \n" - " b.ne 6f // .ACCUMULATE%= \n" - - // rhs zero point * lhs row sum - " fmul z24.s, z20.s, z18.s \n" - " fmul z25.s, z21.s, z18.s \n" - " fmul z26.s, z22.s, z18.s \n" - " fmul z27.s, z23.s, z18.s \n" - - // Applying combined scaling factors to processed block - " fmla z24.s, p0/m, z4.s, z28.s \n" - " fmla z25.s, p0/m, z5.s, z29.s \n" - " fmla z26.s, p0/m, z6.s, z30.s \n" - " fmla z27.s, p0/m, z7.s, z31.s \n" - - "b 7f // .STORE%= \n" - - " 6: // .ACCUMULATE%=: \n" - // Load intermediate result - " .inst 0xa040c738 // ld1w {z24.s-z27.s}, pn9/z, [x25] \n" - - // rhs zero point * lhs row sum - " fmla z24.s, p0/m, z20.s, z18.s \n" - " fmla z25.s, p0/m, z21.s, z18.s \n" - " fmla z26.s, p0/m, z22.s, z18.s \n" - " fmla z27.s, p0/m, z23.s, z18.s \n" - - // Multiply the intermediate results by LHS_SCALE x RHS_SCALE - // and store in the main floating-point accumulator - " fmla z24.s, p0/m, z4.s, z28.s \n" - " fmla z25.s, p0/m, z5.s, z29.s \n" - " fmla z26.s, p0/m, z6.s, z30.s \n" - " fmla z27.s, p0/m, z7.s, z31.s \n" - - "7: // .STORE%=: \n" - // Store the results into memory - " .inst 0xa060c738 // st1w {z24.s-z27.s}, pn9, [x25] \n" - " add x25, x25, %[stride] \n" - - " cmp x12, x15 \n" - " blt 5b // .LOOP_ZA%= \n" - - // Decrement K loop index by bl - " subs x10, x10, x4 \n" - - " b.gt 3b // .LOOP_K_START%= \n" - - 
" 8: // .LOOP_K_END%=: \n" - - // === End of the K loop === - - // Load bias - " ldr x5, [%x[args_ptr], %[offset_bias]]\n" - " add x5, x5, x16\n" - " .inst 0xa040c0ac // ld1w {z12.s - z15.s}, pn8/z, [x5] \n " - - "mov x5, x24\n" - "mov x12, 0\n" - " 10: \n" // Bias loop - // Load acc - " .inst 0xa040c718 //ld1w {z24.s-z27.s}, pn9/z, [x24] \n" - // Add bias - " fadd z24.s, p0/m, z24.s, z12.s \n" - " fadd z25.s, p0/m, z25.s, z13.s \n" - " fadd z26.s, p0/m, z26.s, z14.s \n" - " fadd z27.s, p0/m, z27.s, z15.s \n" - - // Clamp - " .inst 0xc1aac938 //fclamp { z24.s - z27.s }, z9.s, z10.s \n" - - ".inst 0xa060c718 //st1w {z24.s-z27.s}, pn9, [x24] \n" - - " add x24, x24, %[stride] \n" - " add x12, x12, #4\n" - " cmp x12, x15 \n" - " blt 10b // Bias loop \n" - - "mov x24, x5\n" - - " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n" - - // Increment pointer to the quantized values of the lhs matrix - " add x22, x22, x5\n" - - // Increment pointer to the scaling factors of the lhs matrix - " add x23, x23, x5 \n" - - // Update destination pointer - " mov x24, x25 \n" - - // Decrement M loop index - " decw x9, all \n" - - " cmp x9, #0 \n" - " b.gt 2b // .LOOP_M_START%= \n" - - // === End of M loop === - - // Increment output pointer - " incb %[dst], all, mul #4 \n" - - " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n" - - " add x16, x16, x5 \n" - " add x17, x17, x5 \n" - - // Increment N loop index - " incb x8, all \n" - - " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n" - - " b.first 1b // .LOOP_N_START%= \n" - - " 9: // .LOOP_N_END%=: \n" - - // === End of N loop === - // Exit streaming mode - " .inst 0xd503467f // smstop \n" - : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_params] "+r"(rhs_params) - : [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_params] "r"(lhs_params), [stride] "r"(dst_stride_row), - [lut] "r"(lut), [args_ptr] "r"(&ka), [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)), - [offset_m] "I"(offsetof(KernelArgs, m)), [offset_n] "I"(offsetof(KernelArgs, n)), - [offset_min] "I"(offsetof(KernelArgs, min)), [offset_max] "I"(offsetof(KernelArgs, max)), - [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr)), - [offset_bias] "I"(offsetof(KernelArgs, rbias)), [offset_bl] "I"(offsetof(KernelArgs, bl)) - : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1", - "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", - "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x3", "x4", - "x5", "x6", "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", - "x25", "memory", "cc"); + KernelArgs args; + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.dst_stride_row = dst_stride_row; + args.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl); + args.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); + args.bias = args.rhs_packed_stride - nr * kai_num_bytes_bias; + args.m = m; + args.n = n; + args.k = k; + args.bl = bl; + args.lut = lut; + args.min = scalar_min; + args.max = scalar_max; + + kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(&args); } #endif // Architectural features check. 
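Note on the refactor above: the C wrapper now talks to the standalone assembly only through the KernelArgs block, whose layout the .S file addresses via hard-coded byte offsets (the // 0x8, // 0x10, ... comments on the struct). A minimal sketch of compile-time guards that could sit next to the struct definition to keep the two in sync; the asserts below are illustrative only, not part of this patch, with the expected values copied from the offset comments above:

#include <stddef.h> /* offsetof, for the guards below */

/* Hypothetical guards: fail the build if KernelArgs drifts from the
 * offsets hard-coded in the mopa assembly (e.g. ldr x6, [x0, #0x58]). */
_Static_assert(offsetof(KernelArgs, lhs_packed) == 0x08, "asm loads lhs_packed from [x0, #0x8]");
_Static_assert(offsetof(KernelArgs, rhs_packed) == 0x10, "asm loads rhs_packed from [x0, #0x10]");
_Static_assert(offsetof(KernelArgs, lut) == 0x58, "asm loads lut from [x0, #0x58]");
_Static_assert(offsetof(KernelArgs, min) == 0x60, "asm loads min from [x0, #0x60]");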
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S new file mode 100644 index 00000000..1da2db43 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -0,0 +1,187 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + KAI_ASM_CODE(matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa) + KAI_ASM_ALIGN + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa) +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa) + stp x19, x20, [sp, -128 ]! 
+ stp x21, x22, [sp, 16] + stp x23, x24, [sp, 32] + stp x25, x26, [sp, 48] + stp d8, d9, [sp, 64] + stp d10, d11, [sp, 80] + stp d12, d13, [sp, 96] + stp d14, d15, [sp, 112] + KAI_ASM_INST(0xd503477f) // smstart + cntw x14 + ptrue p0.b, all + KAI_ASM_INST(0x25a07810) // ptrue pn8.s + cntw x5 //mr + lsl x5, x5, #2 + whilelt p4.b, xzr, x5 + ldr x6, [x0, #0x58] // lut + KAI_ASM_INST(0xe11f80c0) // ldr zt0, [x6] + ldr x19, [x0, #0x10] // rhs_packed + ld1rw z9.s, p0/z, [x0, #0x60] + ld1rw z10.s, p0/z, [x0, #0x64] + ldr x4, [x0, #0x50] // bl + ldr x21, [x0, #0x18] // dst_stride_row + ldr x20, [x0] // dst + mov x8, #0 + ldr x13, [x0, #0x40] // n + ldr x23, [x0, #0x48] // k + KAI_ASM_INST(0x25ad6511) // whilelt pn9.s, x8, x13, VLx4 + b.none label_9 +KAI_ASM_LABEL(label_1) // N Loop + ldr x9, [x0, #0x38] // m + ldr x22, [x0, #0x8] // lhs_packed + mov x24, x20 +KAI_ASM_LABEL(label_2) // M Loop + mov x26, x19 + mov x3, x22 + cmp x9, x14 + csel x15, x9, x14, lo + lsl x15, x15, #2 + ldr x10, [x0, #0x48] // k + cmp x10, #0 + b.eq label_8 +KAI_ASM_LABEL(label_3) // K Loop + KAI_ASM_INST(0xc00800ff) // zero {za} + mov x11, x4 +KAI_ASM_LABEL(label_4) // Block Loop + KAI_ASM_INST(0xa0404342) //ld1w {z2.s - z3.s}, pn8/z, [x26] + addvl x26, x26, #2 + ld1h {z8.h}, p0/z, [x3] + addvl x3, x3, #1 + KAI_ASM_INST(0xc08a4044) // luti4 {z4.b - z5.b}, zt0, z2[0] + KAI_ASM_INST(0xc08a4066) // luti4 {z6.b - z7.b}, zt0, z3[0] + KAI_ASM_INST(0xa0840100) // smopa za0.s, p0/m, p0/m, z8.b, z4.b + KAI_ASM_INST(0xa0850101) // smopa za1.s, p0/m, p0/m, z8.b, z5.b + KAI_ASM_INST(0xa0860102) // smopa za2.s, p0/m, p0/m, z8.b, z6.b + KAI_ASM_INST(0xa0870103) // smopa za3.s, p0/m, p0/m, z8.b, z7.b + subs x11, x11, #4 + b.gt label_4 + mov w12, #0 + mov x25, x24 + ld1b {z17.b}, p4/z, [x3] // lhs sum + ld1b {z16.b}, p4/z, [x3, #1, mul vl] // lhs scale + addvl x3, x3, #2 + KAI_ASM_INST(0xa040c354) // ld1w { z20.s - z23.s }, pn8/z, [x26] // rhs zp + KAI_ASM_INST(0xa041c340) // ld1w { z0.s - z3.s }, pn8/z, [x26, #4, mul vl ] // rhs scale + addvl x26, x26, #8 + pfalse p3.b +KAI_ASM_LABEL(label_5) + pnext p3.s, p0, p3.s + clastb z19.s, p3, z19.s, z16.s + clastb z18.s, p3, z18.s, z17.s + KAI_ASM_INST(0xc006041c) // mova {z28.b-z31.b}, za0h.b[w12, 0:3] + add w12, w12, #4 + fmul z4.s, z0.s, z19.s + fmul z5.s, z1.s, z19.s + fmul z6.s, z2.s, z19.s + fmul z7.s, z3.s, z19.s + KAI_ASM_INST(0xc132e39c) // scvtf {z28.s-z31.s}, {z28.s-z31.s} + cmp x10, x23 + b.ne label_6 + fmul z24.s, z20.s, z18.s + fmul z25.s, z21.s, z18.s + fmul z26.s, z22.s, z18.s + fmul z27.s, z23.s, z18.s + fmla z24.s, p0/m, z4.s, z28.s + fmla z25.s, p0/m, z5.s, z29.s + fmla z26.s, p0/m, z6.s, z30.s + fmla z27.s, p0/m, z7.s, z31.s + b label_7 +KAI_ASM_LABEL(label_6) + KAI_ASM_INST(0xa040c738) // ld1w {z24.s-z27.s}, pn9/z, [x25] + fmla z24.s, p0/m, z20.s, z18.s + fmla z25.s, p0/m, z21.s, z18.s + fmla z26.s, p0/m, z22.s, z18.s + fmla z27.s, p0/m, z23.s, z18.s + fmla z24.s, p0/m, z4.s, z28.s + fmla z25.s, p0/m, z5.s, z29.s + fmla z26.s, p0/m, z6.s, z30.s + fmla z27.s, p0/m, z7.s, z31.s +KAI_ASM_LABEL(label_7) + KAI_ASM_INST(0xa060c738) // st1w {z24.s-z27.s}, pn9, [x25] + add x25, x25, x21 + cmp x12, x15 + blt label_5 + subs x10, x10, x4 + b.gt label_3 +KAI_ASM_LABEL(label_8) + ldr x5, [x0,0x30] + add x5, x5, x19 + KAI_ASM_INST(0xa040c0ac) // ld1w {z12.s - z15.s}, pn8/z, [x5] + mov x5, x24 + mov x12, 0 +KAI_ASM_LABEL(label_10) // Bias loop + KAI_ASM_INST(0xa040c718) // ld1w {z24.s-z27.s}, pn9/z, [x24] + fadd z24.s, p0/m, z24.s, z12.s + fadd z25.s, p0/m, z25.s, z13.s + fadd z26.s, 
p0/m, z26.s, z14.s + fadd z27.s, p0/m, z27.s, z15.s + KAI_ASM_INST(0xc1aac938) // fclamp { z24.s - z27.s }, z9.s, z10.s + KAI_ASM_INST(0xa060c718) // st1w {z24.s-z27.s}, pn9, [x24] + add x24, x24, x21 + add x12, x12, #4 + cmp x12, x15 + blt label_10 + mov x24, x5 + ldr x5, [x0, #0x20] + add x22, x22, x5 + mov x24, x25 + decw x9, all + cmp x9, #0 + b.gt label_2 + incb x20, all, mul #4 + ldr x5, [x0, #0x28] // rhs_packed_stride + add x19, x19, x5 + incb x8, all + KAI_ASM_INST(0x25ad6511) // whilelt pn9.s, x8, x13, VLx4 + b.first label_1 +KAI_ASM_LABEL(label_9) + KAI_ASM_INST(0xd503467f) // smstop + ldp d14, d15, [sp, 112] + ldp d12, d13, [sp, 96] + ldp d10, d11, [sp, 80] + ldp d8, d9, [sp, 64] + ldp x25, x26, [sp, 48] + ldp x23, x24, [sp, 32] + ldp x21, x22, [sp, 16] + ldp x19, x20, [sp],128 + ret +KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa) +KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c index 782fe628..1aaebccc 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c @@ -13,7 +13,20 @@ #include #include "kai/kai_common.h" - +typedef struct { + float* dst; // 0 + const void* lhs_packed; // 0x8 + const void* rhs_packed; // 0x10 + size_t rhs_packed_stride; // 0x18 + size_t n; // 0x20 + size_t k; // 0x28 + size_t bl; // 0x30 + const int32_t* lut; // 0x38 + float min; // 0x40 + float max; // 0x44 +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 1; static const size_t kai_n_step = 4; // Multiple of vector length @@ -160,211 +173,19 @@ void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( return; } - const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); - - const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); - const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); - - const float* lhs_params = (const float*)((const int8_t*)lhs_packed + (mr * bl)); - const size_t rhs_params_offset = nr * (bl / kai_num_bytes_recip_qvalue_rhs); - const size_t rhs_bias_offset = rhs_packed_stride - nr * kai_num_bytes_bias; - __asm__ volatile( - // Switch to streaming mode with ZA enabling - " .inst 0xd503477f // smstart \n" - - " ptrue p2.b, all \n" - " .inst 0x25607810 // ptrue pn8.h \n" - - " fmov z28.s, #0.0 \n" - - // Initialize ZT0 (Lookup table) - " mov x9, %[lut] \n" - " .inst 0xe11f8120 // ldr zt0, [x9] \n" - - // Initialize the RHS packed and params pointers - " mov x10, %[rhs_packed] \n" - - // Initialize the DST pointer - " mov x5, %[dst] \n" - - // Load the clamp values - " dup z29.s, %w[scalar_min] \n" - " dup z30.s, %w[scalar_max] \n" - - // Iterate over n (x0) - // e.g. 
for(n_idx = 0; n_idx < n; n_idx+=n_step) - " mov x4, #0\n" - " mov x17, %[n] \n" - " .inst 0x25b16491 // whilelt pn9.s, x4, x17, VLx4 \n" - - " b.none 5f // .LOOP_N_END%= \n" - - " 1: // .LOOP_N_START%=: \n" - - // Initialize the LHS packed and params pointers - " mov x2, %[lhs_packed] \n" - " mov x3, %[lhs_params] \n" - " mov x0, x10 \n" - - // Initialize the 4xVL-32bit accumulators to zero - " dup z24.s, #0 \n" - " dup z25.s, #0 \n" - " dup z26.s, #0 \n" - " dup z27.s, #0 \n" - - // Initialize the vector selector for ZA array - " mov w8, #0 \n" - - // Iterate over all K values - // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) - " mov x6, #0 \n" - " whilelt p1.s, x6, %[k] \n" - " b.none 4f // .LOOP_K_END%= \n" - - " 2: // .LOOP_K_START%=: \n" - // Zeroing of inner accumulation array - " .inst 0xc00800ff // zero {za} \n" - - // Iterate over all values in the block - // k_blk_idx = bl - // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 16} - - "mov x13, %[bl] \n" - - "3: // .LOOP_BL_START%=: \n" - // Load the LHS (int8) quantized values - // Load contiguous 16 bytes and replicate. - // For GeMV, we do not interleave the LHS M rows. - " ld1rqb { z0.b }, p2/z , [x2] \n" - " add x2, x2, #16 \n" - - // -- First half - // Load the RHS (int4) quantized values - " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0] \n" - - // Increment the RHS pointer - " addvl x0, x0, #4 \n" - - // Convert Int4 -> Int8 - " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" - " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" - " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" - " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" - - // SDOT indexed - " .inst 0xc15090a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[0] \n" - " .inst 0xc1509520 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[1] \n" - - // -- Second half - - // Load the RHS (int4) quantized values - " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0]\n" - - // Increment the RHS pointer - " addvl x0, x0, #4 \n" - - // Convert Int4 -> Int8 - " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" - " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" - " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" - " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" - - // SDOT indexed - " .inst 0xc15098a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[2] \n" - " .inst 0xc1509d20 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[3] \n" - - // Decrement the block loop index - "subs x13, x13, #16 \n" - - "b.gt 3b // .LOOP_BL_START%= \n" - - // === End of the block loop === - - // Load Z registers with intermediate values from ZA array - " .inst 0xc0060c10 // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] \n" - - // Load 1 fp32 LHS sum and scale value and replicate for VL - " ld1rw z1.s, p2/z, [x2] \n" // sum - " ld1rw z2.s, p2/z, [x2, #4] \n" // scale - - // Increment the LHS param pointer by 8 (2 x sizeof(fp32)) - " add x2, x2, #8 \n" - - // Load 4xVL-32bit (fp32) RHS zp and scales. 
- // If VL=512bit, we load 64 fp32 values, which is equal to the number of output columns (n_step) processed - " .inst 0xa040c004 // ld1w { z4.s - z7.s }, pn8/z, [x0]\n" // zp - " .inst 0xa041c008 // ld1w { z8.s - z11.s }, pn8/z, [x0, #0x4, mul vl ]\n" // scale - - // Increment the RHS pointer - " addvl x0, x0, #8 \n" - - // Convert from int32 to float32 - " .inst 0xc132e210 // scvtf{z16.s - z19.s}, {z16.s - z19.s} \n" - - // za now contains lhs * rhs, this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale )+ rhs_zp * - // lhs_sum rhs zero point * lhs row sum - " fmla z24.s, p2/m, z4.s, z1.s \n" - " fmla z25.s, p2/m, z5.s, z1.s \n" - " fmla z26.s, p2/m, z6.s, z1.s \n" - " fmla z27.s, p2/m, z7.s, z1.s \n" - - // lhs scaling factor * rhs scaling factor - " fmul z8.s, z8.s, z2.s \n" - " fmul z9.s, z9.s, z2.s \n" - " fmul z10.s, z10.s, z2.s \n" - " fmul z11.s, z11.s, z2.s \n" - - // Multiply the intermediate results by LHS_SCALE x RHS_SCALE - // and store in the main floating-point accumulator - " fmla z24.s, p2/m, z16.s, z8.s \n" - " fmla z25.s, p2/m, z17.s, z9.s \n" - " fmla z26.s, p2/m, z18.s, z10.s \n" - " fmla z27.s, p2/m, z19.s, z11.s \n" - - // Increment the number of K values processed and - // go to the next block - " add x6, x6, %[bl] \n" - " whilelt p1.s, x6, %[k] \n" - " b.first 2b // .LOOP_K_START%= \n" - " 4: //.LOOP_K_END%=: \n" - - // Load bias - " .inst 0xa040c014 // ld1w { z20.s - z23.s }, pn8/z, [x0]\n " - - // Add bias - " fadd z24.s, p2/m, z24.s, z20.s \n" - " fadd z25.s, p2/m, z25.s, z21.s \n" - " fadd z26.s, p2/m, z26.s, z22.s \n" - " fadd z27.s, p2/m, z27.s, z23.s \n" - - // Clamp - " .inst 0xc1becbb8 // fclamp { z24.s - z27.s }, z29.s, z30.s \n" - - // Store the results into memory - " .inst 0xa060c4b8 // st1w { z24.s-z27.s }, pn9, [x5] \n" - " incb x4, all \n" - " addvl x5, x5, #4 \n" - - // Update the rhs pointers - " add x10, x10, %[rhs_packed_stride] \n" - - " .inst 0x25b16491 // whilelt pn9.s, x4, %[n], VLx4 \n" - - " b.first 1b // .LOOP_N_START%= \n" - - " 5: // .LOOP_N_END%=: \n" - - // Exit streaming mode - " .inst 0xd503467f //smstop \n" - : - : [lut] "r"(lut), [dst] "r"(dst), [rhs_packed] "r"(rhs_packed), [rhs_params] "r"(rhs_params_offset), - [lhs_packed] "r"(lhs_packed), [lhs_params] "r"(lhs_params), [rhs_packed_stride] "r"(rhs_packed_stride), - [rhs_bias] "r"(rhs_bias_offset), [scalar_min] "r"(scalar_min), [scalar_max] "r"(scalar_max), - [n] "r"((int64_t)n), [k] "r"(k), [bl] "r"(bl) - : "p2", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", - "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", - "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x1", - "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x13", "x17", "memory", "cc"); + KernelArgs args; + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); + args.n = n; + args.k = k; + args.bl = bl; + args.lut = lut; + args.min = scalar_min; + args.max = scalar_max; + + kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(&args); } #endif // Architectural features check. 
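For reference, the per-block epilogue that both SME2 variants implement (spelled out in the assembly comments above) reduces to the following scalar form. This is an illustrative sketch with hypothetical names, not code from the patch; the LHS row sum, RHS zero point, and both scales are stored as fp32 in the packed buffers:

#include <stdint.h>

// int_dot holds sum(lhs_q * rhs_q) over one block of bl K-values,
// i.e. the integer result accumulated in ZA before the scvtf convert.
static inline float kai_block_update_sketch(
    float acc, int32_t int_dot, float lhs_scale, float lhs_sum, float rhs_scale, float rhs_zp) {
    // (lhs * rhs) * (lhs_scale * rhs_scale) + rhs_zp * lhs_sum
    return acc + (float)int_dot * (lhs_scale * rhs_scale) + rhs_zp * lhs_sum;
}

// After the last block: add the per-channel bias, then clamp to [min, max].
static inline float kai_epilogue_sketch(float acc, float bias, float min, float max) {
    acc += bias;
    return acc < min ? min : (acc > max ? max : acc);
}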
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot_asm.S new file mode 100644 index 00000000..82625257 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot_asm.S @@ -0,0 +1,146 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + KAI_ASM_CODE(matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot) + KAI_ASM_ALIGN + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot) +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot) + stp x19, x20, [sp, -112 ]! 
+ stp x21, x22, [sp, 16] + stp x23, x24, [sp, 32] + stp d8, d9, [sp, 48] + stp d10, d11, [sp, 64] + stp d12, d13, [sp, 80] + stp d14, d15, [sp, 96] + KAI_ASM_INST(0xd503477f) // smstart + ptrue p2.b, all + KAI_ASM_INST(0x25607810) // ptrue pn8.h + fmov z28.s, #0.0 + ldr x9, [x0, #0x38] // lut + KAI_ASM_INST(0xe11f8120) // ldr zt0, [x9] + ldr x10, [x0, #0x10] // rhs_packed + ldr x5, [x0] //dst + ld1rw z29.s, p2/z, [x0, #0x40] // min + ld1rw z30.s, p2/z, [x0, #0x44] // max + mov x4, #0 + ldr x24, [x0, #0x20] // n + KAI_ASM_INST(0x25b86491) // whilelt pn9.s, x4, x24, vlx4 + ldr x19, [x0, #0x28] // k + ldr x20, [x0, #0x30] // bl + b.none label_5 +KAI_ASM_LABEL(label_1) // N loop + ldr x21, [x0, #0x8] // lhs_packed + mov x23, x10 + dup z24.s, #0 + dup z25.s, #0 + dup z26.s, #0 + dup z27.s, #0 + mov w8, #0 + mov x6, #0 + whilelt p1.s, x6, x19 + b.none label_4 +KAI_ASM_LABEL(label_2) // K Loop + KAI_ASM_INST(0xc00800ff) // zero {za} + mov x13, x20 +KAI_ASM_LABEL(label_3) // BL loop + ld1rqb { z0.b }, p2/z , [x21] + add x21, x21, #16 + KAI_ASM_INST(0xa040a2ec) // ld1h { z12.h - z15.h }, pn8/z, [x23] + addvl x23, x23, #4 + KAI_ASM_INST(0xc08a4184) // luti4 { z4.b, z5.b }, zt0, z12[0] + KAI_ASM_INST(0xc08a41a6) // luti4 { z6.b, z7.b }, zt0, z13[0] + KAI_ASM_INST(0xc08a41c8) // luti4 { z8.b, z9.b }, zt0, z14[0] + KAI_ASM_INST(0xc08a41ea) // luti4 { z10.b, z11.b }, zt0, z15[0] + KAI_ASM_INST(0xc15090a0) // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[0] + KAI_ASM_INST(0xc1509520) // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[1] + KAI_ASM_INST(0xa040a2ec) // ld1h { z12.h - z15.h }, pn8/z, [x23] + addvl x23, x23, #4 + KAI_ASM_INST(0xc08a4184) // luti4 { z4.b, z5.b }, zt0, z12[0] + KAI_ASM_INST(0xc08a41a6) // luti4 { z6.b, z7.b }, zt0, z13[0] + KAI_ASM_INST(0xc08a41c8) // luti4 { z8.b, z9.b }, zt0, z14[0] + KAI_ASM_INST(0xc08a41ea) // luti4 { z10.b, z11.b }, zt0, z15[0] + KAI_ASM_INST(0xc15098a0) // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[2] + KAI_ASM_INST(0xc1509d20) // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[3] + subs x13, x13, #16 + b.gt label_3 + KAI_ASM_INST(0xc0060c10) // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] + ld1rw z1.s, p2/z, [x21] // sum + ld1rw z2.s, p2/z, [x21, #4] // scale + add x21, x21, #8 + KAI_ASM_INST(0xa040c2e4) // ld1w { z4.s - z7.s }, pn8/z, [x23] // zp + KAI_ASM_INST(0xa041c2e8) // ld1w { z8.s - z11.s }, pn8/z, [x23, #0x4, mul vl ] // scale + addvl x23, x23, #8 + KAI_ASM_INST(0xc132e210) // scvtf{z16.s - z19.s}, {z16.s - z19.s} + fmla z24.s, p2/m, z4.s, z1.s + fmla z25.s, p2/m, z5.s, z1.s + fmla z26.s, p2/m, z6.s, z1.s + fmla z27.s, p2/m, z7.s, z1.s + fmul z8.s, z8.s, z2.s + fmul z9.s, z9.s, z2.s + fmul z10.s, z10.s, z2.s + fmul z11.s, z11.s, z2.s + fmla z24.s, p2/m, z16.s, z8.s + fmla z25.s, p2/m, z17.s, z9.s + fmla z26.s, p2/m, z18.s, z10.s + fmla z27.s, p2/m, z19.s, z11.s + add x6, x6, x20 + whilelt p1.s, x6, x19 + b.first label_2 +KAI_ASM_LABEL(label_4) + KAI_ASM_INST(0xa040c2f4) // ld1w { z20.s - z23.s }, pn8/z, [x23] + fadd z24.s, p2/m, z24.s, z20.s + fadd z25.s, p2/m, z25.s, z21.s + fadd z26.s, p2/m, z26.s, z22.s + fadd z27.s, p2/m, z27.s, z23.s + KAI_ASM_INST(0xc1becbb8) // fclamp { z24.s - z27.s }, z29.s, z30.s + KAI_ASM_INST(0xa060c4b8) // st1w { z24.s-z27.s }, pn9, [x5] + incb x4, all + addvl x5, x5, #4 + ldr x22, [x0, #0x18] + add x10, x10, x22 + KAI_ASM_INST(0x25b86491) // whilelt pn9.s, x4, x24, VLx4 + b.first label_1 +KAI_ASM_LABEL(label_5) + KAI_ASM_INST(0xd503467f) // smstop + ldp d14, d15, [sp, 96] + ldp d12, d13, [sp, 80] + ldp d10, d11, [sp, 64] + ldp 
d8, d9, [sp, 48] + ldp x23, x24, [sp, 32] + ldp x21, x22, [sp, 16] + ldp x19, x20, [sp],112 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot) + + KAI_ASM_END -- GitLab From 4312ee63efb174b05a00857ffbfcc416dc64e433 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Wed, 30 Jul 2025 14:24:30 +0100 Subject: [PATCH 07/13] update test to include new RHS packing functions Signed-off-by: Evie Wright --- CMakeLists.txt | 1 + test/common/test_suite.hpp | 16 + ...atmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp | 330 ++++++++++++------ 3 files changed, 235 insertions(+), 112 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e72fa56a..ffbc8b44 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,6 +161,7 @@ set(KLEIDIAI_FILES_NEON_ASM kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c diff --git a/test/common/test_suite.hpp b/test/common/test_suite.hpp index 5432e9ac..f3464a88 100644 --- a/test/common/test_suite.hpp +++ b/test/common/test_suite.hpp @@ -38,6 +38,22 @@ kai_run_##rhs_pack \ } \ } + +#define UKERNEL_RHS_PACK_VARIANT(rhs_pack) \ + { \ + kai_get_rhs_packed_size_##rhs_pack, \ + kai_get_rhs_packed_offset_##rhs_pack, \ + kai_get_rhs_offset_##rhs_pack, \ + kai_run_##rhs_pack \ + } + +#define UKERNEL_LHS_PACK_VARIANT(lhs_pack) \ + { \ + kai_get_lhs_packed_size_##lhs_pack, \ + kai_get_lhs_packed_offset_##lhs_pack, \ + kai_get_lhs_offset_##lhs_pack, \ + kai_run_##lhs_pack \ + } // clang-format on namespace kai::test { diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp index dfb5b2fb..8253162d 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp @@ -14,6 +14,8 @@ #include #include +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h" @@ -21,7 +23,10 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.h" #include "test/common/buffer.hpp" +#include "test/common/compare.hpp" #include 
"test/common/cpu_info.hpp" #include "test/common/int4.hpp" #include "test/common/matmul_test_common.hpp" @@ -38,24 +43,136 @@ namespace kai::test { -static const std::array, 4> +// clang-format off +#define UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1(name, features_check, lhs_pack, rhs_pack, s0s1_input) \ +{ \ + {UKERNEL_MATMUL_VARIANT(name), "kai_matmul_" #name, features_check}, \ + UKERNEL_LHS_PACK_VARIANT(lhs_pack), \ + UKERNEL_RHS_PACK_VARIANT(rhs_pack), \ + s0s1_input \ +} +// clang-format on + +// Interface for the LHS and RHS packed size and packing functions +using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon); +using kai_get_rhs_packed_size_func_t = + decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); +using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon); +using kai_get_rhs_packed_offset_func_t = + decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); +using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon); +using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); +using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon); +using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); + +// Micro-kernel interface +struct kai_qai4c32p_pack_functions { + kai_get_rhs_packed_size_func_t packed_size; + kai_get_rhs_packed_offset_func_t get_packed_offset; + kai_get_rhs_offset_func_t get_offset; + kai_run_rhs_pack_func_t run_pack; +}; + +struct kai_qsi8d32p_pack_functions { + kai_get_lhs_packed_size_func_t packed_size; + kai_get_lhs_packed_offset_func_t get_packed_offset; + kai_get_lhs_offset_func_t get_offset; + kai_run_lhs_pack_func_t run_pack; +}; + +template +struct UkernelMatmulPackVariantWithS0S1 { + /// Interface for matmul variant. 
+ UkernelVariant ukernel; + + L lhs_pack_interface; + R rhs_pack_interface; + + bool rhs_s0s1_input; + + UkernelMatmulPackVariantWithS0S1() = delete; +}; + +static const std::array< + UkernelMatmulPackVariantWithS0S1< + kai_matmul_clamp_f32_qsi8d32p_qai4c32p_ukernel, kai_qsi8d32p_pack_functions, kai_qai4c32p_pack_functions>, + 8> variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p = { - {{UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod), - "kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod}, - {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm), - "kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", cpu_has_i8mm}, - {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod), - "kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", cpu_has_dotprod}, - {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod), - "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod", cpu_has_dotprod}}}; + {UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod, cpu_has_dotprod, + lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), + UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32pscalef32_f32_neon, + rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), + UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod, cpu_has_dotprod, + lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), + UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, + lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), + UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f32_neon, + rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon, false), + UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, + lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon, false), + UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f32_neon, + rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon, true), + UKERNEL_MATMUL_PACK_VARIANT_WITH_S0S1( + clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, + lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon, true)}}; + +// Executes the LHS packing micro-kernel. 
+static inline std::tuple<Buffer, size_t> pack_lhs_qsi8d32p(
+    const kai_qsi8d32p_pack_functions& pack_interface, size_t M, size_t K, size_t bl, size_t mr, size_t kr, size_t sr,
+    const Buffer& lhs_values_qsi8, size_t stride, size_t rect_start_row, size_t rect_height) {
+    const auto imp_packed_lhs_size = pack_interface.packed_size(M, K, bl, mr, kr, sr);
+    Buffer imp_packed_lhs(imp_packed_lhs_size, 0);
+
+    auto lhs_offset = pack_interface.get_offset(rect_start_row, stride);
+    auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, bl, mr, kr, sr);
+
+    kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
+        rect_height, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(lhs_values_qsi8.data() + lhs_offset), stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
+
+    return {std::move(imp_packed_lhs), lhs_packed_offset};
+}
+
+// Executes the RHS packing micro-kernel.
+static inline std::tuple<Buffer, size_t> pack_rhs_qai4c32p(
+    const kai_qai4c32p_pack_functions& pack_interface, size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr,
+    const Buffer& rhs_values_qai4, const bool has_bias, const Buffer& biases, const Buffer& rhs_scales,
+    const Buffer& rhs_zp, bool s0s1_input, size_t rect_start_row) {
+    // Cast to unsigned int
+    auto rhs_qau4s1s0 = cast_qsu4_qsi4(rhs_values_qai4.data(), N * K);
+
+    const auto imp_packed_rhs_size = pack_interface.packed_size(N, K, nr, kr, bl);
+    Buffer imp_packed_rhs(imp_packed_rhs_size);
+    auto rhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, nr, kr, bl);
+
+    // Runs the RHS packing micro-kernel.
+    kai_rhs_pack_nxk_qai4c32p_params params{};
+    params.lhs_zero_point = 1;
+    params.rhs_zero_point = 8;
+
+    pack_interface.run_pack(
+        1, N, K, nr, kr, sr, bl,
+        reinterpret_cast<const uint8_t*>(s0s1_input ? convert_s0s1_s1s0(rhs_qau4s1s0).data() : rhs_qau4s1s0.data()),
+        rhs_zp.data(), has_bias ? biases.data() : nullptr, rhs_scales.data(), imp_packed_rhs.data(), 0, &params);
+
+    return {std::move(imp_packed_rhs), rhs_packed_offset};
+}
 
-class MatMulTest_f32_qsi8d32p_qai4c32p : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias> {};
+using MatMulTestPortionedParamsWithBias_WithBL = std::tuple<size_t, MatMulShape, size_t, MatrixPortion, bool>;
+class MatMulTest_f32_qsi8d32p_qai4c32p : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias_WithBL> {};
 
 TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) {
     // Verify LHS quant and pack int8 kernel behaves same for int4 and int8 matmul kernels,
     // when the block-depth is same for different values of kr, sr.
-    const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam();
+    const auto& [variant_index, matmul_shape, bl, portion, has_bias] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index);
 
     const std::uint32_t seed = 0;
@@ -63,17 +180,20 @@ TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) {
     const size_t M = matmul_shape.m;
     const size_t N = matmul_shape.n;
     const size_t K = matmul_shape.k;
-    const size_t bl = 32;
 
-    const auto mr = ukernel_variant.interface.get_mr();
-    const auto nr = ukernel_variant.interface.get_nr();
-    const auto kr = ukernel_variant.interface.get_kr();
-    const auto sr = ukernel_variant.interface.get_sr();
+    if (K % bl != 0) {
+        GTEST_SKIP() << "K must be a multiple of bl";
+    }
+
+    const auto mr = ukernel_variant.ukernel.interface.get_mr();
+    const auto nr = ukernel_variant.ukernel.interface.get_nr();
+    const auto kr = ukernel_variant.ukernel.interface.get_kr();
+    const auto sr = ukernel_variant.ukernel.interface.get_sr();
 
-    auto m_step = ukernel_variant.interface.get_m_step();
+    auto m_step = ukernel_variant.ukernel.interface.get_m_step();
     ASSERT_TRUE(m_step % mr == 0);
 
-    auto n_step = ukernel_variant.interface.get_n_step();
+    auto n_step = ukernel_variant.ukernel.interface.get_n_step();
     ASSERT_TRUE(n_step % nr == 0);
 
     const auto rect = portion.compute_portion(M, N, m_step, n_step);
@@ -88,47 +208,32 @@ TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) {
 
     // Runs the LHS packing micro-kernel.
     const auto lhs_start_row = rect.start_row();
-    const auto imp_packed_lhs_size =
-        kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr, sr);
-    Buffer imp_packed_lhs(imp_packed_lhs_size, 0);
-
     auto lhs_stride = K * sizeof(float);
-    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, lhs_stride);
-    auto lhs_packed_offset =
-        kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, K, bl, mr, kr, sr);
 
-    kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
-        rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset),
-        lhs_stride, imp_packed_lhs.data() + lhs_packed_offset);
+    auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qsi8d32p(
+        ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, lhs_stride, lhs_start_row, rect.height());
 
     const size_t kr_qsi8 = kr / sr;
     const size_t sr_qsi8 = 1;
-    const auto imp_packed_lhs_qsi8_size =
-        kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr_qsi8, sr_qsi8);
-    Buffer imp_packed_lhs_qsi8(imp_packed_lhs_qsi8_size, 0);
-    auto lhs_qsi8_packed_offset =
-        kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, K, bl, mr, kr_qsi8, sr_qsi8);
+    auto [imp_packed_lhs_qsi8, lhs_qsi8_packed_offset] = pack_lhs_qsi8d32p(
+        ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr_qsi8, sr_qsi8, ref_lhs, lhs_stride, lhs_start_row,
+        rect.height());
 
     ASSERT_EQ(lhs_qsi8_packed_offset, lhs_packed_offset);
 
-    kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
-        rect.height() /* m */, K, bl, mr, kr_qsi8, sr_qsi8, 0,
-        reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
-        imp_packed_lhs_qsi8.data() + lhs_qsi8_packed_offset);
-
     auto* imp_packed_lhs_ptr = reinterpret_cast<const int8_t*>(imp_packed_lhs.data());
     auto* imp_packed_lhs_qsi8_ptr = reinterpret_cast<const int8_t*>(imp_packed_lhs_qsi8.data());
-    for (size_t i = 0; i < imp_packed_lhs_qsi8_size; i++) {
+    for (size_t i = 0; i < ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr); i++) {
         ASSERT_EQ(imp_packed_lhs_ptr[i], imp_packed_lhs_qsi8_ptr[i]);
     }
 }
 
 TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) {
-    const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam();
+    const auto& [variant_index, matmul_shape, bl, portion, has_bias] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index);
 
-    if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
+    if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
         GTEST_SKIP() << "Unsupported CPU feature";
     }
 
@@ -137,21 +242,24 @@ TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) {
     const size_t M = matmul_shape.m;
     const size_t N = matmul_shape.n;
     const size_t K = matmul_shape.k;
-    const size_t bl = 32;
 
-    const auto mr = ukernel_variant.interface.get_mr();
-    const auto nr = ukernel_variant.interface.get_nr();
-    const auto kr = ukernel_variant.interface.get_kr();
-    const auto sr = ukernel_variant.interface.get_sr();
+    if (K % bl != 0) {
+        GTEST_SKIP() << "K must be a multiple of bl";
+    }
+
+    const auto mr = ukernel_variant.ukernel.interface.get_mr();
+    const auto nr = ukernel_variant.ukernel.interface.get_nr();
+    const auto kr = ukernel_variant.ukernel.interface.get_kr();
+    const auto sr = ukernel_variant.ukernel.interface.get_sr();
 
     if (mr == 1 && M > 1) {
         GTEST_SKIP() << "Kernel does not support M != 1";
     }
 
-    auto m_step = ukernel_variant.interface.get_m_step();
+    auto m_step = ukernel_variant.ukernel.interface.get_m_step();
     ASSERT_TRUE(m_step % mr == 0);
 
-    auto n_step = ukernel_variant.interface.get_n_step();
+    auto n_step = ukernel_variant.ukernel.interface.get_n_step();
     ASSERT_TRUE(n_step % nr == 0);
 
     const auto rect = portion.compute_portion(M, N, m_step, n_step);
@@ -189,22 +297,13 @@ TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) {
 
     // Runs the LHS packing micro-kernel.
     const auto lhs_start_row = rect.start_row();
-    const auto imp_packed_lhs_size =
-        kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(M, K, bl, mr, kr, sr);
-    Buffer imp_packed_lhs(imp_packed_lhs_size, 0);
-
-    auto lhs_stride = K * sizeof(float);
-    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, lhs_stride);
-    auto lhs_packed_offset =
-        kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(lhs_start_row, K, bl, mr, kr, sr);
-    auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
+    auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qsi8d32p(
+        ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, K * sizeof(float), lhs_start_row,
+        rect.height());
+    auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
     ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
 
-    kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
-        rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset),
-        lhs_stride, imp_packed_lhs.data() + lhs_packed_offset);
-
     // Prepare the offsets as the RHS packing kernel expects the scaled zero-points in float.
     const size_t num_blocks_per_row = round_up_division(K, bl);
     const size_t ref_zp_size = N * num_blocks_per_row;
 
@@ -216,78 +315,68 @@ TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) {
             reinterpret_cast<const float*>(ref_rhs_scales.data())[i];
     }
 
-    // Cast to unsigned int
-    auto ref_rhs_qau4 = cast_qsu4_qsi4(ref_rhs_qai4.data(), N * K);
-
-    // Reorder the nibble pairing to s0s1
-    const auto ref_rhs_qau4s0s1 = convert_s0s1_s1s0(ref_rhs_qau4);
-
-    const auto imp_packed_rhs_size =
-        kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(N, K, nr, kr, bl);
-    Buffer imp_packed_rhs(imp_packed_rhs_size);
     const auto rhs_start_row = rect.start_col();
-    auto rhs_packed_offset =
-        kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(rhs_start_row, K, nr, kr, bl);
-    auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
-    ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
-
-    // Runs the RHS packing micro-kernel.
-    kai_rhs_pack_nxk_qai4c32p_params params{};
-    params.lhs_zero_point = 1;
-    params.rhs_zero_point = 8;
+    auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qai4c32p(
+        ukernel_variant.rhs_pack_interface, N, K, bl, nr, kr, sr, ref_rhs_qai4, has_bias, ref_biases, ref_rhs_scales,
+        ref_rhs_zp_f32, ukernel_variant.rhs_s0s1_input, rhs_start_row);
 
-    kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(
-        1, N, K, nr, kr, sr, bl, reinterpret_cast<const uint8_t*>(ref_rhs_qau4s0s1.data()), ref_rhs_zp_f32.data(),
-        has_bias ? ref_biases.data() : nullptr, ref_rhs_scales.data(), imp_packed_rhs.data(), 0, &params);
+    auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
+    ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
 
     const auto dst_stride_row = N * sizeof(float);
     const auto dst_stride_col = sizeof(float);
     const auto dst_offset =
-        ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
+        ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
     const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
     ASSERT_EQ(dst_offset, ref_dst_offset);
 
     // Runs the GEMM micro-kernel.
-    const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
+    const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N);
     ASSERT_EQ(imp_dst_size, ref_dst.size());
     Buffer imp_dst(imp_dst_size);
-    ukernel_variant.interface.run_matmul(
+    ukernel_variant.ukernel.interface.run_matmul(
         rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
         imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
         dst_stride_row, dst_stride_col, clamp_min, clamp_max);
 
     // Compares the output of the micro-kernels against the output of the reference implementation for the portion
     // tested.
-    for (size_t y = 0; y < rect.height(); ++y) {
-        for (size_t x = 0; x < rect.width(); ++x) {
-            const auto imp_value =
-                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
-            const auto ref_value =
-                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
-            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
-            if (rel_error > 0.0001F) {
-                ASSERT_EQ(imp_value, ref_value);
-            }
-        }
-    }
+    DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
+    DataFormat dst_format = DataFormat(DataType::FP32);
+    const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
+    ASSERT_TRUE(success);
 }
 
 INSTANTIATE_TEST_SUITE_P(
     MatMul, MatMulTest_f32_qsi8d32p_qai4c32p,
     testing::Combine(
         testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.size()),
         testing::Values(
-            MatMulShape{1, 2, 32},    //
-            MatMulShape{1, 3, 32},    //
-            MatMulShape{1, 4, 32},    //
-            MatMulShape{1, 5, 32},    //
-            MatMulShape{3, 3, 32},    //
-            MatMulShape{4, 4, 32},    //
-            MatMulShape{5, 5, 32},    //
-            MatMulShape{32, 64, 64},  //
-            MatMulShape{16, 32, 64},  //
-            MatMulShape{8, 32, 64},   //
-            MatMulShape{15, 32, 32},  //
+            MatMulShape{1, 64, 32},    //
+            MatMulShape{1, 63, 32},    //
+            MatMulShape{1, 65, 32},    //
+            MatMulShape{1, 64, 64},    //
+            MatMulShape{1, 64, 128},   //
+            MatMulShape{1, 128, 32},   //
+            MatMulShape{1, 128, 128},  //
+            MatMulShape{1, 2, 32},     //
+            MatMulShape{1, 3, 32},     //
+            MatMulShape{1, 4, 32},     //
+            MatMulShape{1, 5, 32},     //
+            MatMulShape{3, 3, 32},     //
+            MatMulShape{4, 4, 32},     //
+            MatMulShape{5, 5, 32},     //
+            MatMulShape{32, 128, 32},  //
+            MatMulShape{15, 64, 64},   //
+            MatMulShape{17, 64, 64},   //
+            MatMulShape{16, 63, 64},   //
+            MatMulShape{16, 64, 64},   //
+            MatMulShape{16, 65, 64},   //
+            MatMulShape{32, 64, 64},   //
+            MatMulShape{16, 32, 64},   //
+            MatMulShape{8, 32, 64},    //
+            MatMulShape{15, 32, 32},   //
            MatMulShape{77, 99, 64}),
+        testing::Values(32, 64),
         testing::Values(
             MatrixPortion(0, 0, 1, 1),     // Full matrix.
             MatrixPortion(0, 0, 1, 0.25),  // Leftmost portion.
@@ -300,12 +389,29 @@ INSTANTIATE_TEST_SUITE_P(
         testing::Bool()),
     [](const auto& info) {
         const auto variant_idx = std::get<0>(info.param);
-        const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).name};
+        const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).ukernel.name};
         const auto shape = std::get<MatMulShape>(info.param);
-        const auto portion = std::get<2>(info.param);
-        const auto has_bias = std::get<3>(info.param);
+        const auto bl = std::get<2>(info.param);
+        const auto portion = std::get<3>(info.param);
+        const auto has_bias = std::get<4>(info.param);
+
+        std::ostringstream sstream;
+        sstream << name << "__";
+        PrintTo(shape, &sstream);
+        sstream << "__BL_" << bl << "_";
+        if (has_bias) {
+            sstream << "_withBias_";
+        } else {
+            sstream << "_noBias_";
+        }
+        if (variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).rhs_s0s1_input) {
+            sstream << "_RHS_s0s1__";
+        } else {
+            sstream << "_RHS_s1s0__";
+        }
+        PrintTo(portion, &sstream);
 
-        return test_description(name, shape, portion, has_bias);
+        return sstream.str();
     });
 }  // namespace kai::test
-- 
GitLab


From 423bffd3b7f839cd40f185b40923ba28596c732a Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Wed, 30 Jul 2025 14:32:15 +0100
Subject: [PATCH 08/13] Update Bazel build and fix clang tidy errors

Signed-off-by: Anitha Raj
---
 kai/ukernels/matmul/BUILD.bazel                              | 4 ++++
 ..._rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c | 5 +++--
 ..._rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c | 5 +++--
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel
index 498dbd76..7d2b492a 100644
--- a/kai/ukernels/matmul/BUILD.bazel
+++ b/kai/ukernels/matmul/BUILD.bazel
@@ -36,6 +36,8 @@ NEON_KERNELS = [
"pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon", "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon", "pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon", + "pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon", + "pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon", "pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon", @@ -197,6 +199,8 @@ SME2_KERNELS_ASM = [ "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", + "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa", + "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot", "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c index cc479bb0..7ea51547 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c @@ -148,8 +148,9 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon( const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; - const int8_t dst_qs0 = src_x0_lo | - (src_x0_hi << 4); // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) *block_dst_row = dst_qs0; block_dst_row += sizeof(uint8_t); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c index 0fe89fc3..e9e34b75 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c @@ -148,8 +148,9 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; - const int8_t dst_qs0 = src_x0_lo | - (src_x0_hi << 4); // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) *block_dst_row = dst_qs0; block_dst_row += sizeof(uint8_t); -- GitLab From 1e9175fbbad697e01ee3e5ba024c2ad5e488dae6 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 30 Jul 2025 15:08:18 +0100 Subject: [PATCH 09/13] Update Architectural feature guards Signed-off-by: Anitha Raj --- 
...ul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 2 +- ..._matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c index 8f7d2ad6..47ebc6b9 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) +#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2. #else // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c index 1aaebccc..7f1b2d8a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) +#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) && !defined(_M_ARM64) #error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2. #else // Architectural features check. 
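A note on the guard added in this patch: in C preprocessor conditionals `&&` binds tighter than `||`, so the expression above parses as `!defined(__aarch64__) || (!(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) && !defined(_M_ARM64))`; PATCH 12 below replaces it with an explicitly parenthesised form (and also narrows the feature test to FEAT_SVE2). A minimal sketch of how the two groupings can diverge; `a`, `b`, and `c` are hypothetical stand-ins for the `defined(...)` tests, not the real feature macros:

```c
/* a = !defined(__aarch64__), b = !(SVE2 || SME2), c = !defined(_M_ARM64);
 * hypothetical stand-ins used only to show the precedence difference. */
#include <stdio.h>

int main(void) {
    for (int a = 0; a <= 1; ++a) {
        for (int b = 0; b <= 1; ++b) {
            for (int c = 0; c <= 1; ++c) {
                int implicit = a || (b && c); /* how "a || b && c" actually parses */
                int grouped = (a || b) && c;  /* the explicitly parenthesised form */
                if (implicit != grouped) {
                    printf("a=%d b=%d c=%d: %d vs %d\n", a, b, c, implicit, grouped);
                }
            }
        }
    }
    return 0;
}
```

The grouping matters for MSVC targeting Arm64, where `_M_ARM64` is defined but `__aarch64__` is not: with the implicit grouping the `#error` still fires, while the parenthesised form lets the file compile.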
-- GitLab From 96bab098167fbca52feb6463fd68417b5eb88c77 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Wed, 30 Jul 2025 15:28:02 +0100 Subject: [PATCH 10/13] insert missing arch check Signed-off-by: Evie Wright --- test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp index 8253162d..d22c072b 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp @@ -133,7 +133,7 @@ static inline std::tuple pack_lhs_qsi8d32p( auto lhs_offset = pack_interface.get_offset(rect_start_row, stride); auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, bl, mr, kr, sr); - kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon( + pack_interface.run_pack( rect_height, K, bl, mr, kr, sr, 0, reinterpret_cast(lhs_values_qsi8.data() + lhs_offset), stride, imp_packed_lhs.data() + lhs_packed_offset); @@ -175,6 +175,10 @@ TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) { const auto& [variant_index, matmul_shape, bl, portion, has_bias] = GetParam(); const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index); + if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { + GTEST_SKIP() << "Unsupported CPU feature"; + } + const std::uint32_t seed = 0; const size_t M = matmul_shape.m; -- GitLab From c8648a4ae9cd8d3b84db35b265bb1d5a45be293b Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 30 Jul 2025 15:40:56 +0100 Subject: [PATCH 11/13] Fix whitespace in asm end directive in GEMM kernel Signed-off-by: Anitha Raj --- ...p_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S index 1da2db43..c0dabb40 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -183,5 +183,6 @@ KAI_ASM_LABEL(label_9) ldp x21, x22, [sp, 16] ldp x19, x20, [sp],128 ret -KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa) -KAI_ASM_END + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa) + + KAI_ASM_END -- GitLab From 6b3c0bba7e18a2cd4b7408e768541bcb72e3f258 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 30 Jul 2025 16:09:45 +0100 Subject: [PATCH 12/13] Update architecture guards Signed-off-by: Anitha Raj --- ..._clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 4 ++-- ...atmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c index 47ebc6b9..6f5c41c3 100644 --- 
a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c
@@ -3,8 +3,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) && !defined(_M_ARM64)
-#error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2.
+#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
+#error This file must be compiled for AArch64, FEAT_SVE2.
 #else  // Architectural features check.
 
 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c
index 7f1b2d8a..6f176662 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c
@@ -3,8 +3,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) && !defined(_M_ARM64)
-#error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2.
+#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
+#error This file must be compiled for AArch64, FEAT_SVE2.
 #else  // Architectural features check.
 
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h"
 
-- 
GitLab


From d777e7fd106a8bb0dece076b65ab6a7e522182d2 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 31 Jul 2025 09:37:44 +0100
Subject: [PATCH 13/13] Add parentheses to expression

Signed-off-by: Anitha Raj
---
 ...kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c | 2 +-
 ...kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c
index 7ea51547..6b876f83 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon.c
@@ -98,7 +98,7 @@ void kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s0s1_f32_f32_f32_neon(
     const size_t dst_block_data_size = (bl / 2) * nr;
     const size_t dst_num_rows = kai_roundup(n, nr) / nr;
     const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size;
-    const size_t k_block_length_in_bytes = block_length * sizeof(uint8_t) / 2;
+    const size_t k_block_length_in_bytes = (block_length * sizeof(uint8_t)) / 2;
     const size_t k_interleaved_v = 1U;
     const size_t rhs_zero_point = params->rhs_zero_point;
 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c
index e9e34b75..e69befbd 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon.c
@@ -98,7 +98,7 @@ void
kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon( const size_t dst_block_data_size = (bl / 2) * nr; const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size; - const size_t k_block_length_in_bytes = block_length * sizeof(uint8_t) / 2; + const size_t k_block_length_in_bytes = (block_length * sizeof(uint8_t)) / 2; const size_t k_interleaved_v = 1U; const size_t rhs_zero_point = params->rhs_zero_point; -- GitLab
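A closing note on this final patch: `*` and `/` have equal precedence and associate left to right in C, so `block_length * sizeof(uint8_t) / 2` already evaluated as the parenthesised form; the added parentheses only make the intended grouping explicit for readers and linters. A small self-contained check, using an arbitrary example value for `block_length`:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

int main(void) {
    size_t block_length = 32; /* arbitrary example value, not taken from the kernel */
    /* '*' and '/' share precedence and associate left to right, so these agree. */
    size_t implicit = block_length * sizeof(uint8_t) / 2;
    size_t grouped = (block_length * sizeof(uint8_t)) / 2;
    assert(implicit == grouped); /* always holds; the patch is readability-only */
    return 0;
}
```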