From e6be1984fbfdb565763885e9708fe58cecc0000d Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 14 Jul 2025 14:55:29 +0100 Subject: [PATCH 1/3] SME Matmul Micro-kernel F16 <- (QSI8D32) LHS x (QAI4C32) RHS * Matrix multiplication (1xN) micro-kernels that compute the matrix multiplication of a dynamically quantized, symmetric, signed 8-bit integer LHS matrix with per-block quantization (QSI8D32) and a quantized, asymmetric, signed 4-bit integer RHS matrix with per-block quantization (QAI4C32), accumulating the result into a half-precision (F16) output. Optimized for SME2 technology. Signed-off-by: Anitha Raj --- CHANGELOG.md | 1 + CMakeLists.txt | 1 + ...qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c | 372 ++++++++++++++++++ ...qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h | 145 +++++++ 4 files changed, 519 insertions(+) create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 99cac4f3..f970362c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4C32 RHS with BF16 output, optimized for FEAT_DotProd. - New SME micro-kernels: - Matrix multiplication (1xN) of F32 LHS and RHS with F32 output, using instructions compatible with FEAT_SME. + - Matrix multiplication (1xN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_SME2. - Matrix multiplication (1xN) of F16 LHS and RHS with F16 output, using instructions compatible with FEAT_SME. - Convert SME transposed RHS packing micro-kernels to pure assembly: - kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme diff --git a/CMakeLists.txt b/CMakeLists.txt index 1146a801..a91c119a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -319,6 +319,7 @@ set(KLEIDIAI_FILES_SME2_ASM set(KLEIDIAI_FILES_SME2 ${KLEIDIAI_FILES_SME2_ASM} + kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c new file mode 100644 index 00000000..d6b11b84 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c @@ -0,0 +1,372 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) +#error This file must be compiled for AArch64 with FEAT_SVE2 or FEAT_SME2 enabled. +#else // Architectural features check.
+ +#include "kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h" + +#include +#include + +#include "kai/kai_common.h" + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; // Multiple of vector length +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; // Multiple of vector length +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_sum_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +// Look-up table used for int4->int8 convert +static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { + const size_t mr = kai_get_mr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME(bl % kai_bl == 0); + KAI_ASSUME((k % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_n_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) { + return kai_sr; +} + +size_t 
kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, size_t k, size_t bl) { + const size_t m_step = kai_get_m_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t mr = kai_get_mr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + KAI_ASSUME((m_idx % m_step) == 0); + + return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + const size_t n_step = kai_get_n_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, size_t n_idx, size_t dst_stride) { + const size_t m_step = kai_get_m_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t n_step = kai_get_n_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + KAI_ASSUME((m_idx % m_step) == 0); + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + KAI_ASSUME(m == 1); + + KAI_UNUSED(dst_stride_row); + + if (m == 0) { + return; + } + + const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); + + const size_t mr = kai_get_mr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(); + + const float* lhs_params = (const float*)((const int8_t*)lhs_packed + (mr * bl)); + const size_t rhs_params_offset = nr * (bl / kai_num_bytes_recip_qvalue_rhs); + const size_t rhs_bias_offset = rhs_packed_stride - nr * kai_num_bytes_bias; + __asm__ volatile( + // Switch to streaming mode with ZA enabling + " .inst 0xd503477f // smstart \n" + + " ptrue p2.b, all \n" + " .inst 0x25607810 // ptrue pn8.h \n" + + // Initialize ZT0 (Lookup table) + " mov x9, %[lut] \n" + " .inst 0xe11f8120 // ldr zt0, [x9] \n" + + // Initialize the RHS packed and params pointers + " mov x10, %[rhs_packed] \n" + + // Initialize the DST pointer + " mov x5, %[dst] \n" + + // Load the clamp values + " dup z30.s, %w[scalar_min] \n" + " dup z31.s, %w[scalar_max] \n" + + // Iterate over n (x4) + // e.g.
for(n_idx = 0; n_idx < n; n_idx+=n_step) + " mov x4, #0\n" + " mov x17, %[n] \n" + " .inst 0x25716491 // whilelt pn9.h, x4, x17, VLx4 \n" + + " b.none 5f // .LOOP_N_END%= \n" + + " 1: // .LOOP_N_START%=: \n" + + // Initialize the LHS packed and params pointers + " mov x2, %[lhs_packed] \n" + " mov x3, %[lhs_params] \n" + " mov x0, x10 \n" + + // Initialize the 4xVL-32bit accumulators to zero + " dup z24.s, #0 \n" + " dup z25.s, #0 \n" + " dup z26.s, #0 \n" + " dup z27.s, #0 \n" + + // Initialize the vector selector for ZA array + " mov w8, #0 \n" + + // Iterate over all K values + // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) + " mov x6, #0 \n" + " whilelt p1.s, x6, %[k] \n" + " b.none 4f // .LOOP_K_END%= \n" + + " 2: // .LOOP_K_START%=: \n" + // Zeroing of inner accumulation array + " .inst 0xc00800ff // zero {za} \n" + + // Iterate over all values in the block + // k_blk_idx = bl + // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 16} + + "mov x13, %[bl] \n" + + "3: // .LOOP_BL_START%=: \n" + // Load the LHS (int8) quantized values + // Load contiguous 16 bytes and replicate. + // For GeMV, we do not interleave the LHS M rows. + " ld1rqb { z0.b }, p2/z , [x2] \n" + " add x2, x2, #16 \n" + + // -- First half + // Load the RHS (int4) quantized values + " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0] \n" + + // Increment the RHS pointer + " addvl x0, x0, #4 \n" + + // Convert Int4 -> Int8 + " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" + " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" + " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" + " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" + + // SDOT indexed + " .inst 0xc15090a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[0] \n" + " .inst 0xc1509520 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[1] \n" + + // -- Second half + + // Load the RHS (int4) quantized values + " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0]\n" + + // Increment the RHS pointer + " addvl x0, x0, #4 \n" + + // Convert Int4 -> Int8 + " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" + " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" + " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" + " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" + + // SDOT indexed + " .inst 0xc15098a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[2] \n" + " .inst 0xc1509d20 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[3] \n" + + // Decrement the block loop index + "subs x13, x13, #16 \n" + + "b.gt 3b // .LOOP_BL_START%= \n" + + // === End of the block loop === + + // Load Z registers with intermediate values from ZA array + " .inst 0xc0060c10 // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] \n" + + // Load 1 fp32 LHS sum and scale value and replicate for VL + " ld1rw z1.s, p2/z, [x2] \n" // sum + " ld1rw z2.s, p2/z, [x2, #4] \n" // scale + + // Increment the LHS param pointer by 8 (2 x sizeof(fp32)) + " add x2, x2, #8 \n" + + // Load 4xVL-32bit (fp32) RHS zp and scales. 
+ // If VL=512bit, we load 64 fp32 values, which is equal to the number of output columns (n_step) processed + " .inst 0xa040c004 // ld1w { z4.s - z7.s }, pn8/z, [x0]\n" // zp + " .inst 0xa041c008 // ld1w { z8.s - z11.s }, pn8/z, [x0, #0x4, mul vl ]\n" // scale + + // Increment the RHS pointer + " addvl x0, x0, #8 \n" + + // ZA now contains lhs * rhs; this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale) + rhs_zp * lhs_sum. + // rhs zero point * lhs row sum + " fmla z24.s, p2/m, z4.s, z1.s \n" + " fmla z25.s, p2/m, z5.s, z1.s \n" + " fmla z26.s, p2/m, z6.s, z1.s \n" + " fmla z27.s, p2/m, z7.s, z1.s \n" + + // lhs scaling factor * rhs scaling factor + " fmul z8.s, z8.s, z2.s \n" + " fmul z9.s, z9.s, z2.s \n" + " fmul z10.s, z10.s, z2.s \n" + " fmul z11.s, z11.s, z2.s \n" + + // Convert from int32 to float32 + " .inst 0xc132e210 // scvtf {z16.s - z19.s}, {z16.s - z19.s} \n" + + // Multiply the intermediate results by LHS_SCALE x RHS_SCALE + // and store in the main floating-point accumulator + " fmla z24.s, p2/m, z16.s, z8.s \n" + " fmla z25.s, p2/m, z17.s, z9.s \n" + " fmla z26.s, p2/m, z18.s, z10.s \n" + " fmla z27.s, p2/m, z19.s, z11.s \n" + + // Increment the number of K values processed and + // go to the next block + " add x6, x6, %[bl] \n" + " whilelt p1.s, x6, %[k] \n" + " b.first 2b // .LOOP_K_START%= \n" + " 4: // .LOOP_K_END%=: \n" + + // Load bias + " .inst 0xa040c014 // ld1w { z20.s - z23.s }, pn8/z, [x0]\n " + + // Add bias + " fadd z24.s, p2/m, z24.s, z20.s \n" + " fadd z25.s, p2/m, z25.s, z21.s \n" + " fadd z26.s, p2/m, z26.s, z22.s \n" + " fadd z27.s, p2/m, z27.s, z23.s \n" + + // Clamp + " .inst 0xc1bfcbd8 // fclamp { z24.s - z27.s }, z30.s, z31.s \n" + + // Convert to FP16 + " .inst 0xc120e31c // fcvt z28.h, {z24.s- z25.s} \n" + " .inst 0xc120e35d // fcvt z29.h, {z26.s- z27.s} \n" + + // Store the results into memory + " .inst 0xa06024bc // st1h { z28.h - z29.h }, pn9, [x5] \n" + " incb x4, all \n" + " addvl x5, x5, #2 \n" + + // Update the rhs pointers + " add x10, x10, %[rhs_packed_stride] \n" + + " .inst 0x25716491 // whilelt pn9.h, x4, %[n], VLx4 \n" + + " b.first 1b // .LOOP_N_START%= \n" + + " 5: // .LOOP_N_END%=: \n" + + // Exit streaming mode + " .inst 0xd503467f // smstop \n" + : + : [lut] "r"(lut), [dst] "r"(dst), [rhs_packed] "r"(rhs_packed), [rhs_params] "r"(rhs_params_offset), + [lhs_packed] "r"(lhs_packed), [lhs_params] "r"(lhs_params), [rhs_packed_stride] "r"(rhs_packed_stride), + [rhs_bias] "r"(rhs_bias_offset), [scalar_min] "r"(scalar_min), [scalar_max] "r"(scalar_max), + [n] "r"((int64_t)n), [k] "r"(k), [bl] "r"(bl) + : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x13", "x17", "memory", "cc"); +} + +#endif // Architectural features check.
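Usage note (illustrative sketch, not part of the patch): a minimal example of how a caller might drive this 1xN micro-kernel, assuming the operands were packed with the helpers named in the header below; the +/-FLT_MAX bounds simply make the final clamp a no-op. The kernel walks the whole N range internally with predicated tail handling, so a single call covers any n; the *_offset getters matter only when the N range is split across threads. For a feel of the packed-RHS layout: with bl = 32 and a 512-bit vector length (nr = 64), each block packs 32 int4 values into 16 bytes plus a 4-byte scale and a 4-byte zero point (24 bytes per block), so kai_get_rhs_packed_stride(k, 32) = 64 * 24 * (k / 32) + 64 * 4 bias bytes.

#include <float.h>
#include <stdint.h>

#include "kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h"

/* Sketch: one GeMV call (m == 1); dst receives n F16 values. */
static void gemv_f16_sketch(
    size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, uint16_t* dst) {
    kai_run_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
        /* m */ 1, n, k, bl, lhs_packed, rhs_packed, dst,
        /* dst_stride_row */ n * sizeof(uint16_t),
        /* dst_stride_col */ sizeof(uint16_t),
        /* scalar_min */ -FLT_MAX, /* scalar_max */ FLT_MAX);
}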
diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h new file mode 100644 index 00000000..7feea8ad --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h @@ -0,0 +1,145 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_run_lhs_quant_pack_qsi8d32pscalef32_f16_neon to dynamically quantize and pack the LHS matrix in a single +/// step. +/// -# @ref kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (32) quantization (qsi8d32) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] bl Block length. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] bl Block length.
+/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (32) quantization (qsi8d32) and packed +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (32) quantization (qai4c32) and packed. +/// Output tile: (rows x cols) = 1 x 4 VL (Vector Length) +/// +/// Instruction used: SME2 (sdot) +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus -- GitLab From 8155ae4814f592a899e3852697ddf59a084a0a0e Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 18 Jul 2025 15:03:33 +0100 Subject: [PATCH 2/3] SME Matmul Micro-kernel F16 <- (QSI8D32) LHS x (QAI4C32) RHS * Matrix multiplication (MxN) micro-kernels that compute the matrix multiplication of a dynamically quantized, symmetric, signed 8-bit integer LHS matrix with per-block quantization (QSI8D32) and a quantized, asymmetric, signed 4-bit integer RHS matrix with per-block quantization (QAI4C32), accumulating the result into a half-precision (F16) output. Optimized for SME2 technology. Signed-off-by: Anitha Raj --- CMakeLists.txt | 2 + ...32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 495 ++++++++++++++++++ ...32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h | 148 ++++++ 3 files changed, 645 insertions(+) create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a91c119a..04a44fb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,6 +327,8 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c + kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c + kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c ) add_library(kleidiai) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c new file mode 100644 index 00000000..9cabe505 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -0,0 +1,495 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) +#error This file must be compiled for AArch64 with FEAT_SVE2 or FEAT_SME2 enabled. +#else // Architectural features check. + +#include "kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +// Compute args +static const size_t kai_m_step = 1; // Multiple of vector length +static const size_t kai_n_step = 4; // Multiple of vector length +// Packing args +static const size_t kai_mr = 1; // Multiple of vector length +static const size_t kai_nr = 4; // Multiple of vector length +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_sum_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +// Look-up table used for int4->int8 convert +static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { + const size_t mr = kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME(bl % kai_bl == 0); + KAI_ASSUME((k % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + + size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_m_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_n_step * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_mr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_nr * kai_get_sme_vector_length_u32(); +} + +size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, size_t k, size_t bl) { + const size_t m_step = kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t mr = kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + + KAI_ASSUME((m_idx % m_step) == 0); + + return (m_idx / mr) * 
kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + const size_t n_step = kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride) { + const size_t m_step = kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t n_step = kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + KAI_ASSUME((m_idx % m_step) == 0); + KAI_ASSUME((n_idx % n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + + if (m == 0) { + return; + } + + typedef struct { + size_t lhs_packed_stride; + size_t rhs_packed_stride; + size_t mr; + size_t m; + size_t n; + size_t bl; + float min; + float max; + size_t rbias; + } KernelArgs; + + const size_t mr = kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(); + const float* lhs_params = (const float*)((const int8_t*)lhs_packed + (mr * bl)); + const float* rhs_params = (const float*)((const int8_t*)rhs_packed + nr * (bl / kai_num_bytes_recip_qvalue_rhs)); + + KernelArgs ka; + ka.mr = mr; + ka.m = m; + ka.n = n; + ka.bl = bl; + ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl); + ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); + ka.min = scalar_min; + ka.max = scalar_max; + ka.rbias = ka.rhs_packed_stride - nr * kai_num_bytes_bias; + + __asm__ volatile( + // Switch to streaming mode with ZA enabling + " .inst 0xd503477f // smstart \n" + + // Constants + // - SVLs + " cntw x14 \n" + // - ptrue + " ptrue p0.b, all \n" + " .inst 0x25a07810 // ptrue pn8.s \n" + + // Predicate for loading fp32 parameters + " ldr x5, [%x[args_ptr], %[offset_mr]]\n" + " lsl x5, x5, #2 \n" + " whilelt p4.b, xzr, x5 \n" + + // Initialize ZT0 (Lookup table) + " mov x6, %[lut]\n" + " .inst 0xe11f80c0 // ldr zt0, [x6] \n" + + // Initialize the RHS packed and scale pointers + " mov x16, %[rhs_packed] \n" + " mov x17, %[rhs_params] \n" + + // Load the clamp values + " ld1rw z9.s, p0/z, [%x[args_ptr], %[offset_min]] \n" + " ld1rw z10.s, p0/z, [%x[args_ptr], %[offset_max]] \n" + " ldr x4, [%x[args_ptr], %[offset_bl]]\n" + + // Iterate over n (x8) + // e.g.
for(n_idx = 0; n_idx < n; n_idx+=n_step) + " mov x8, #0 \n" + " ldr x0, [%x[args_ptr], %[offset_n]] \n" + " .inst 0x25606511 //whilelt pn9.h, x8, x0, VLx4 \n" + + " b.none 9f // .LOOP_N_END%= \n" + + " 1: // .LOOP_N_START%=: \n" + + // Iterate over m (x9) + // e.g. for(m_idx = 0; m_idx < m; m_idx+=m_step) + " ldr x9, [%x[args_ptr], %[offset_m]]\n" + + // Initialize the LHS packed and scale pointers + " mov x22, %[lhs_packed] \n" + " mov x23, %[lhs_params] \n" + + // Initialize the DST pointer + " mov x24, %[dst] \n" + + " 2: // .LOOP_M_START%=: \n" + + // Address offset for the left and right quantized values + " mov x20, #0 \n" + " mov x21, #0 \n" + " mov x17, x16 \n" + " mov x3, x22 \n" + + // Number of output rows to store -> min(SVLs, loop M index) + " cmp x9, x14 \n" + " csel x15, x9, x14, lo \n" + " lsl x15, x15, #2 \n" + + // Iterate over all K values + // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) + " mov x10, %[K] \n" + + // Skip processing if K=0 + " cmp x10, #0 \n" + " b.eq 8f // .LOOP_K_END%= \n" + + " 3: // .LOOP_K_START%=: \n" + + // Zeroing of ZA accumulator + " .inst 0xc00800ff // zero {za} \n" + + // Iterate over all values in the block + // k_blk_idx = bl + // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 4} + " mov x11, x4 \n" + + " 4: // .LOOP_BL_START%=: \n" + + // Load right matrix row + " .inst 0xa0404222 //ld1w {z2.s - z3.s}, pn8/z, [x17] \n" + " addvl x17, x17, #2 \n" + + // Load left matrix column + " ld1h {z8.h}, p0/z, [x3, x20, lsl #1] \n" + " addvl x3, x3, #1 \n" + + // Convert Int4 -> Int8 + " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n" + " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n" + + // Outer-products + " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n" + " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n" + " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n" + " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n" + + // Decrement the block loop index + " subs x11, x11, #4 \n" + + " b.gt 4b // .LOOP_BL_START%= \n" + + // === End of the block loop === + + // Store loop index + " mov w12, #0 \n" + + // Copy destination pointer for store loop + " mov x25, x24 \n" + + // Load the fp32 sum, scaling factors for the left matrix block + // and zero points, scaling factors for the right matrix block + + " ld1b {z17.b}, p4/z, [x3] \n" // lhs sum + " ld1b {z16.b}, p4/z, [x3, #1, mul vl] \n" // lhs scale + " addvl x3, x3, #2 \n" + + " .inst 0xa040c234 // ld1w { z20.s - z23.s }, pn8/z, [x17]\n" // rhs zp + " .inst 0xa041c220 // ld1w { z0.s - z3.s }, pn8/z, [x17, #4, mul vl ]\n" // rhs scale + " addvl x17, x17, #8 \n" + + // Predicate for the selection of a scaling among the vector + " pfalse p3.b \n" + + " 5: // .LOOP_ZA%=: \n" + + // Select and replicate scaling factor for the right block + " pnext p3.s, p0, p3.s \n" + " clastb z19.s, p3, z19.s, z16.s \n" + " clastb z18.s, p3, z18.s, z17.s \n" + + // Get data from za + " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n" + " add w12, w12, #4 \n" + + // ZA now contains lhs * rhs; this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale) + rhs_zp * lhs_sum. + // rhs scaling factor * lhs scaling factor + " fmul z4.s, z0.s, z19.s \n" + " fmul z5.s, z1.s, z19.s \n" + " fmul z6.s, z2.s, z19.s \n" + " fmul z7.s, z3.s, z19.s \n" + + // Convert from int32 to fp32 + " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" + + " cmp x10, %[K] \n" + " b.ne 6f // .ACCUMULATE%= \n" + + // rhs zero point * lhs row sum + " fmul z24.s, z20.s, z18.s \n"
+ " fmul z25.s, z21.s, z18.s \n" + " fmul z26.s, z22.s, z18.s \n" + " fmul z27.s, z23.s, z18.s \n" + + // Applying combined scaling factors to processed block + " fmla z24.s, p0/m, z4.s, z28.s \n" + " fmla z25.s, p0/m, z5.s, z29.s \n" + " fmla z26.s, p0/m, z6.s, z30.s \n" + " fmla z27.s, p0/m, z7.s, z31.s \n" + + "b 7f // .STORE%= \n" + + " 6: // .ACCUMULATE%=: \n" + // Load intermediate result + + // Load acc + " ld1h z24.s , p0/z, [x25] \n" + " ld1h z25.s , p0/z, [x25, #1, MUL VL] \n" + " ld1h z26.s , p0/z, [x25, #2, MUL VL] \n" + " ld1h z27.s , p0/z, [x25, #3, MUL VL] \n" + " fcvt z24.s, p0/m, z24.h \n" + " fcvt z25.s, p0/m, z25.h \n" + " fcvt z26.s, p0/m, z26.h \n" + " fcvt z27.s, p0/m, z27.h \n" + + // rhs zero point * lhs row sum + " fmla z24.s, p0/m, z20.s, z18.s \n" + " fmla z25.s, p0/m, z21.s, z18.s \n" + " fmla z26.s, p0/m, z22.s, z18.s \n" + " fmla z27.s, p0/m, z23.s, z18.s \n" + + // Multiply the intermediate results by LHS_SCALE x RHS_SCALE + // and store in the main floating-point accumulator + " fmla z24.s, p0/m, z4.s, z28.s \n" + " fmla z25.s, p0/m, z5.s, z29.s \n" + " fmla z26.s, p0/m, z6.s, z30.s \n" + " fmla z27.s, p0/m, z7.s, z31.s \n" + + "7: // .STORE%=: \n" + + // Convert to FP16 + " .inst 0xc120e31c // fcvt z28.h, {z24.s- z25.s} \n" + " .inst 0xc120e35d // fcvt z29.h, {z26.s- z27.s} \n" + + // Store the results into memory + " .inst 0xa060273c // st1h { z28.h - z29.h }, pn9, [x25] \n" + " add x25, x25, %[stride] \n" + + " cmp x12, x15 \n" + " blt 5b // .LOOP_ZA%= \n" + + // Decrement K loop index by bl + " subs x10, x10, x4 \n" + + " b.gt 3b // .LOOP_K_START%= \n" + + " 8: // .LOOP_K_END%=: \n" + + // === End of the K loop === + + // Load bias + " ldr x5, [%x[args_ptr], %[offset_bias]]\n" + " add x5, x5, x16\n" + " .inst 0xa040c0ac // ld1w {z12.s - z15.s}, pn8/z, [x5] \n " + + "mov x12, 0\n" + + " 10: \n" // Bias loop + + // Load acc + " ld1h z24.s , p0/z, [x24] \n" + " ld1h z25.s , p0/z, [x24, #1, MUL VL] \n" + " ld1h z26.s , p0/z, [x24, #2, MUL VL] \n" + " ld1h z27.s , p0/z, [x24, #3, MUL VL] \n" + " fcvt z24.s, p0/m, z24.h \n" + " fcvt z25.s, p0/m, z25.h \n" + " fcvt z26.s, p0/m, z26.h \n" + " fcvt z27.s, p0/m, z27.h \n" + + // Add bias + " fadd z24.s, p0/m, z24.s, z12.s \n" + " fadd z25.s, p0/m, z25.s, z13.s \n" + " fadd z26.s, p0/m, z26.s, z14.s \n" + " fadd z27.s, p0/m, z27.s, z15.s \n" + + // Clamp + " .inst 0xc1aac938 //fclamp { z24.s - z27.s }, z9.s, z10.s \n" + // Convert to FP16 + " .inst 0xc120e31c // fcvt z28.h, {z24.s- z25.s} \n" + " .inst 0xc120e35d // fcvt z29.h, {z26.s- z27.s} \n" + + // Store the results into memory + " .inst 0xa060271c // st1h { z28.h - z29.h }, pn9, [x24] \n" + + " add x24, x24, %[stride] \n" + " add x12, x12, #4\n" + " cmp x12, x15 \n" + " blt 10b // Bias loop \n" + + " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n" + + // Increment pointer to the quantized values of the lhs matrix + " add x22, x22, x5\n" + + // Increment pointer to the scaling factors of the lhs matrix + " add x23, x23, x5 \n" + + // Update destination pointer + " mov x24, x25 \n" + + // Decrement M loop index + " decw x9, all \n" + + " cmp x9, #0 \n" + " b.gt 2b // .LOOP_M_START%= \n" + + // === End of M loop === + + // Increment output pointer + " incb %[dst], all, mul #2 \n" + + " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n" + + " add x16, x16, x5 \n" + " add x17, x17, x5 \n" + + // Increment N loop index + " incb x8, all \n" + + " .inst 0x25606511 //whilelt pn9.h, x8, x0, VLx4 \n" + + " b.first 1b // .LOOP_N_START%= \n" + + " 9: // .LOOP_N_END%=: \n" + + 
// === End of N loop === + + // Exit streaming mode + " .inst 0xd503467f // smstop \n" + : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_params] "+r"(rhs_params) + : [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_params] "r"(lhs_params), [stride] "r"(dst_stride_row), + [lut] "r"(lut), [args_ptr] "r"(&ka), [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)), + [offset_m] "I"(offsetof(KernelArgs, m)), [offset_n] "I"(offsetof(KernelArgs, n)), + [offset_min] "I"(offsetof(KernelArgs, min)), [offset_max] "I"(offsetof(KernelArgs, max)), + [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr)), + [offset_bias] "I"(offsetof(KernelArgs, rbias)), [offset_bl] "I"(offsetof(KernelArgs, bl)) + : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x3", "x4", "x5", "x6", "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", "memory", "cc"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h new file mode 100644 index 00000000..27fddb59 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h @@ -0,0 +1,148 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_run_lhs_quant_pack_qsi8d32pscalef32_f16_neon to dynamically quantize and pack the LHS matrix in a single +/// step. +/// -# @ref kai_run_rhs_pack_nxk_qai4c32ps1s0_qau4c32s1s0_f32_f32_f32_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the nr value, which must be used to pack the RHS matrix.
+/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) +/// values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) and packed +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) and packed. +/// Output tile: (rows x cols) = 1 VL x 4 VL (Vector Length) +/// +/// Instruction used: SME2 (MOPA) +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels.
The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus -- GitLab From 32f1575dfe19215cb841a38449b9dccd3b0e391e Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 18 Jul 2025 17:17:08 +0100 Subject: [PATCH 3/3] Fix build warnings and update packing parameters Signed-off-by: Anitha Raj --- ...32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c | 92 +++++++++---------- ...qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c | 4 +- 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c index 9cabe505..3356624e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -20,8 +20,8 @@ static const size_t kai_n_step = 4; // Multiple of vector length // Packing args static const size_t kai_mr = 1; // Multiple of vector length static const size_t kai_nr = 4; // Multiple of vector length -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; // LHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric)) static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; @@ -187,14 +187,14 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( __asm__ volatile( // Switch to streaming mode with ZA enabling - " .inst 0xd503477f // smstart \n" + " .inst 0xd503477f\n" // smstart // Constants // - SVLs " cntw x14 \n" // - ptrue " ptrue p0.b, all \n" - " .inst 0x25a07810 // ptrue pn8.s \n" + " .inst 0x25a07810 \n" // ptrue pn8.s // Predicate for loading fp32 parameters " ldr x5, [%x[args_ptr], %[offset_mr]]\n" @@ -203,7 +203,7 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( // Initialize ZT0 (Lookup table) " mov x6, %[lut]\n" - " .inst 0xe11f80c0 // ldr zt0, [x6] \n" + " .inst 0xe11f80c0 \n" // ldr zt0, [x6] // Initialize the RHS packed and scale pointers " mov x16, %[rhs_packed] \n" @@ -218,11 +218,11 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) " mov x8, #0 \n" " ldr x0, [%x[args_ptr], %[offset_n]] \n" - " .inst 0x25606511 //whilelt pn9.h, x8, x0, VLx4 \n" + " .inst 0x25606511 \n" // whilelt pn9.h, x8, x0, VLx4 - " b.none 9f // .LOOP_N_END%= \n" + " b.none 9f\n" // .LOOP_N_END%= - " 1: // .LOOP_N_START%=: \n" + " 1: \n" // .LOOP_N_START%=: // Iterate over m (x9) // e.g. for(m_idx = 0; m_idx < m; m_idx+=m_step) @@ -235,7 +235,7 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( // Initialize the DST pointer " mov x24, %[dst] \n" - " 2: // .LOOP_M_START%=: \n" + " 2: \n" // .LOOP_M_START%=: // Address offset for the left and right quantized values " mov x20, #0 \n" @@ -254,22 +254,22 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( // Skip processing if K=0 " cmp x10, #0 \n" - " b.eq 8f // .LOOP_K_END%= \n" + " b.eq 8f\n" // .LOOP_K_END%= - " 3: // .LOOP_K_START%=: \n" + " 3: \n" // .LOOP_K_START%=: // Zeroing of ZA accumulator - " .inst 0xc00800ff // zero {za} \n" + " .inst 0xc00800ff \n" // zero {za} // Iterate over all values in the block // k_blk_idx = bl // e.g. while(k_blk_idx > 0) {...
k_blk_idx -= 4} " mov x11, x4 \n" - " 4: // .LOOP_BL_START%=: \n" + " 4: \n" // .LOOP_BL_START%=: // Load right matrix row - " .inst 0xa0404222 //ld1w {z2.s - z3.s}, pn8/z, [x17] \n" + " .inst 0xa0404222 \n" // ld1w {z2.s - z3.s}, pn8/z, [x17] " addvl x17, x17, #2 \n" // Load left matrix column @@ -277,19 +277,19 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " addvl x3, x3, #1 \n" // Convert Int4 -> Int8 - " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n" - " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n" + " .inst 0xc08a4044 \n" // luti4 {z4.b - z5.b}, zt0, z2[0] + " .inst 0xc08a4066 \n" // luti4 {z6.b - z7.b}, zt0, z3[0] // Outer-products - " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n" - " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n" - " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n" - " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n" + " .inst 0xa0840100 \n" // smopa za0.s, p0/m, p0/m, z8.b, z4.b + " .inst 0xa0850101 \n" // smopa za1.s, p0/m, p0/m, z8.b, z5.b + " .inst 0xa0860102 \n" // smopa za2.s, p0/m, p0/m, z8.b, z6.b + " .inst 0xa0870103 \n" // smopa za3.s, p0/m, p0/m, z8.b, z7.b // Decrement the block loop index " subs x11, x11, #4 \n" - " b.gt 4b // .LOOP_BL_START%= \n" + " b.gt 4b \n" // .LOOP_BL_START%= // === End of the block loop === @@ -306,14 +306,14 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " ld1b {z16.b}, p4/z, [x3, #1, mul vl] \n" // lhs scale " addvl x3, x3, #2 \n" - " .inst 0xa040c234 // ld1w { z20.s - z23.s }, pn8/z, [x17]\n" // rhs zp - " .inst 0xa041c220 // ld1w { z0.s - z3.s }, pn8/z, [x17, #4, mul vl ]\n" // rhs scale + " .inst 0xa040c234 \n" // ld1w { z20.s - z23.s }, pn8/z, [x17] // rhs zp + " .inst 0xa041c220 \n" // ld1w { z0.s - z3.s }, pn8/z, [x17, #4, mul vl ] // rhs scale " addvl x17, x17, #8 \n" // Predicate for the selection of a scaling among the vector " pfalse p3.b \n" - " 5: // .LOOP_ZA%=: \n" + " 5: \n" // .LOOP_ZA%=: // Select and replicate scaling factor for the right block " pnext p3.s, p0, p3.s \n" @@ -321,7 +321,7 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " clastb z18.s, p3, z18.s, z17.s \n" // Get data from za - " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n" + " .inst 0xc006041c \n" // mova {z28.b-z31.b}, za0h.b[w12, 0:3] " add w12, w12, #4 \n" // ZA now contains lhs * rhs; this needs to be updated to (lhs * rhs) * (lhs_scale * rhs_scale) + rhs_zp * lhs_sum. // rhs scaling factor * lhs scaling factor " fmul z4.s, z0.s, z19.s \n" @@ -332,10 +332,10 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " fmul z7.s, z3.s, z19.s \n" // Convert from int32 to fp32 - " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" + " .inst 0xc132e39c \n" // scvtf {z28.s-z31.s}, {z28.s-z31.s} " cmp x10, %[K] \n" - " b.ne 6f // .ACCUMULATE%= \n" + " b.ne 6f \n" // .ACCUMULATE%= // rhs zero point * lhs row sum " fmul z24.s, z20.s, z18.s \n" @@ -349,9 +349,9 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " fmla z26.s, p0/m, z6.s, z30.s \n" " fmla z27.s, p0/m, z7.s, z31.s \n" - "b 7f // .STORE%= \n" + "b 7f \n" // .STORE%= - " 6: // .ACCUMULATE%=: \n" + " 6: \n" // .ACCUMULATE%=: // Load intermediate result // Load acc @@ -377,32 +377,32 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " fmla z26.s, p0/m, z6.s, z30.s \n" " fmla z27.s, p0/m, z7.s, z31.s \n" - "7: // .STORE%=: \n" + "7: \n" // .STORE%=: // Convert to FP16 - " .inst 0xc120e31c // fcvt
z28.h, {z24.s- z25.s} \n" - " .inst 0xc120e35d // fcvt z29.h, {z26.s- z27.s} \n" + " .inst 0xc120e31c \n" // fcvt z28.h, {z24.s- z25.s} + " .inst 0xc120e35d \n" // fcvt z29.h, {z26.s- z27.s} // Store the results into memory - " .inst 0xa060273c // st1h { z28.h - z29.h }, pn9, [x25] \n" + " .inst 0xa060273c \n" // st1h { z28.h - z29.h }, pn9, [x25] " add x25, x25, %[stride] \n" " cmp x12, x15 \n" - " blt 5b // .LOOP_ZA%= \n" + " blt 5b \n" // .LOOP_ZA%= // Decrement K loop index by bl " subs x10, x10, x4 \n" - " b.gt 3b // .LOOP_K_START%= \n" + " b.gt 3b \n" // .LOOP_K_START%= - " 8: // .LOOP_K_END%=: \n" + " 8: \n" // .LOOP_K_END%=: // === End of the K loop === // Load bias " ldr x5, [%x[args_ptr], %[offset_bias]]\n" " add x5, x5, x16\n" - " .inst 0xa040c0ac // ld1w {z12.s - z15.s}, pn8/z, [x5] \n " + " .inst 0xa040c0ac \n" // ld1w {z12.s - z15.s}, pn8/z, [x5] "mov x12, 0\n" @@ -425,18 +425,18 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " fadd z27.s, p0/m, z27.s, z15.s \n" // Clamp - " .inst 0xc1aac938 //fclamp { z24.s - z27.s }, z9.s, z10.s \n" + " .inst 0xc1aac938 \n" // fclamp { z24.s - z27.s }, z9.s, z10.s // Convert to FP16 - " .inst 0xc120e31c // fcvt z28.h, {z24.s- z25.s} \n" - " .inst 0xc120e35d // fcvt z29.h, {z26.s- z27.s} \n" + " .inst 0xc120e31c \n" // fcvt z28.h, {z24.s- z25.s} + " .inst 0xc120e35d \n" // fcvt z29.h, {z26.s- z27.s} // Store the results into memory - " .inst 0xa060271c // st1h { z28.h - z29.h }, pn9, [x24] \n" + " .inst 0xa060271c \n" // st1h { z28.h - z29.h }, pn9, [x24] " add x24, x24, %[stride] \n" " add x12, x12, #4\n" " cmp x12, x15 \n" - " blt 10b // Bias loop \n" + " blt 10b \n" // Bias loop " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n" @@ -453,7 +453,7 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( " decw x9, all \n" " cmp x9, #0 \n" - " b.gt 2b // .LOOP_M_START%= \n" + " b.gt 2b \n" // .LOOP_M_START%= // === End of M loop === @@ -468,11 +468,11 @@ void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa( // Increment N loop index " incb x8, all \n" - " .inst 0x25606511 //whilelt pn9.h, x8, x0, VLx4 \n" + " .inst 0x25606511 \n" // whilelt pn9.h, x8, x0, VLx4 - " b.first 1b // .LOOP_N_START%= \n" + " b.first 1b \n" // .LOOP_N_START%= - " 9: // .LOOP_N_END%=: \n" + " 9: \n" // .LOOP_N_END%=: // === End of N loop === diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c index d6b11b84..4654c7a7 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c @@ -20,8 +20,8 @@ static const size_t kai_n_step = 4; // Multiple of vector length // Packing args static const size_t kai_mr = 1; static const size_t kai_nr = 4; // Multiple of vector length -static const size_t kai_kr = 4; -static const size_t kai_sr = 1; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; -- GitLab
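Usage note (illustrative sketch, not part of the patches): the m_step/n_step getters exist so callers can tile the output across threads. Below is a sketch of that pattern for the MxN MOPA micro-kernel, under the same packing assumptions as the GeMV note above; the tile indices must be multiples of m_step/n_step, as the offset getters assert.

#include <stdint.h>

#include "kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h"

/* Sketch: compute one (m_idx, n_idx) output tile; a thread pool would hand
 * each worker a different tile. */
static void matmul_tile_sketch(
    size_t m, size_t n, size_t k, size_t bl, size_t m_idx, size_t n_idx,
    const void* lhs_packed, const void* rhs_packed,
    void* dst, size_t dst_stride_row, float clamp_min, float clamp_max) {
    const size_t m_step = kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t n_step = kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();

    /* Clip the tile shape at the matrix edges. */
    const size_t m_blk = (m - m_idx) < m_step ? (m - m_idx) : m_step;
    const size_t n_blk = (n - n_idx) < n_step ? (n - n_idx) : n_step;

    /* Offsets into the packed operands and the F16 destination. */
    const void* lhs_tile = (const uint8_t*)lhs_packed +
        kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(m_idx, k, bl);
    const void* rhs_tile = (const uint8_t*)rhs_packed +
        kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(n_idx, k, bl);
    void* dst_tile = (uint8_t*)dst +
        kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(m_idx, n_idx, dst_stride_row);

    kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(
        m_blk, n_blk, k, bl, lhs_tile, rhs_tile, dst_tile,
        dst_stride_row, sizeof(uint16_t), clamp_min, clamp_max);
}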