From c6787018d3476b790c47ff0aee973a6709072c6d Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 20 Jun 2025 14:17:44 +0100 Subject: [PATCH] Matmul micro-kernel (1xN) F32/F16 <- (QSI8D32) LHS x (QAI4C32) RHS * Matrix multiplication (1xN) micro-kernels that compute the product of a dynamically quantized symmetric signed 8-bit integer LHS matrix with per-block quantization (QSI8D32) and a quantized asymmetric signed 4-bit integer RHS matrix with per-block quantization (QAI4C32), accumulating the result into a single-precision (F32) or half-precision (F16) output. Optimized for FEAT_DotProd and the packing parameter kr = 8. Signed-off-by: Anitha Raj --- CHANGELOG.md | 4 + CMakeLists.txt | 4 + kai/ukernels/matmul/BUILD.bazel | 2 + ...qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c | 179 ++++++++++++++++++ ...qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h | 150 +++++++++++++++ ...d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S | 155 +++++++++++++++ ...qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c | 178 +++++++++++++++++ ...qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h | 150 +++++++++++++++ ...d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S | 154 +++++++++++++++ ...atmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp | 7 +- ...atmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp | 7 +- 11 files changed, 986 insertions(+), 4 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S diff --git a/CHANGELOG.md b/CHANGELOG.md index c96364c8..6ebf228d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- New Advanced SIMD micro-kernels: + - Matrix multiplication (1xN) micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_DotProd. + - Matrix multiplication (1xN) micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. + ## v1.10.0 - Convert SME and SME2 imatmul micro-kernels to use pure assembly, and add MSVC support. 
Affects: diff --git a/CMakeLists.txt b/CMakeLists.txt index 66a34548..a7d6185c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -115,6 +115,8 @@ set(KLEIDIAI_FILES_NEON_FP16_DOTPROD_ASM kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod_asm.S @@ -185,6 +187,8 @@ set(KLEIDIAI_FILES_NEON_DOTPROD_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 1abee3c7..038a4c01 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -80,6 +80,7 @@ FP16_DOTPROD_KERNELS_ASM = [ "matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", "matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", "matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", + "matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod", "matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", "matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", ] @@ -115,6 +116,7 @@ DOTPROD_KERNELS_ASM = [ "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod", + "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod", "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", 
"matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", ] diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c new file mode 100644 index 00000000..38d27f8c --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c @@ -0,0 +1,179 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \ + !defined(_M_ARM64) +#error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + void* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_sum_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { + return kai_mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + + size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, size_t k, size_t bl) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + 
void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + + if (m == 0) { + return; + } + const size_t num_subblocks = bl / kai_bl; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h new file mode 100644 index 00000000..57a15d7b --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h @@ -0,0 +1,150 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_run_lhs_quant_pack_qsi8d32pscalef32_f16_neon to dynamically quantize and pack the LHS matrix in a single +/// step. +/// -# @ref kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) +/// values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. 
+/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) and packed. +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. 
+void kai_run_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S new file mode 100644 index 00000000..87208024 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S @@ -0,0 +1,155 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x21, #0x8 + movi v28.16b, #0xf0 + mov x15, #0x20 + ldr x14, [x0, #0x40] + ldr x13, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x12, [x0, #0x8] + ldr x11, [x0, #0x10] + ldr x10, [x0, #0x30] + madd x15, x14, x15, x21 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + mov x26, x20 + mul x15, x13, x15 +KAI_ASM_LABEL(label_1) // Row loop + mov x25, x11 + mov x24, x10 + add x23, x9, x28 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x12 + movi v27.16b, #0x0 + mov x21, x13 +KAI_ASM_LABEL(label_3) // Block loop + movi v26.4s, #0x0 + mov x20, x14 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q25, [x25, #0x0] + ldr q24, [x22, #0x0] + subs x20, x20, #0x1 + ldr q23, [x25, #0x10] + ldr q22, [x25, #0x20] + ldr q21, [x25, #0x30] + ldr q20, [x22, #0x10] + add x25, x25, #0x40 + add x22, x22, #0x20 + shl v19.16b, v25.16b, #0x4 + and v25.16b, v25.16b, v28.16b + shl v18.16b, v23.16b, #0x4 + shl v17.16b, v22.16b, #0x4 + shl v16.16b, v21.16b, #0x4 + and v23.16b, v23.16b, v28.16b + KAI_ASM_INST(0x4f98e27a) // sdot v26.4s, v19.16b, v24.4b[0] + and v22.16b, v22.16b, v28.16b + and v21.16b, v21.16b, v28.16b + KAI_ASM_INST(0x4fb8e25a) // sdot v26.4s, v18.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ea3a) // sdot v26.4s, v17.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8ea1a) // sdot v26.4s, v16.16b, v24.4b[3] + KAI_ASM_INST(0x4f94e33a) // sdot v26.4s, v25.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e2fa) // sdot v26.4s, v23.16b, v20.4b[1] + KAI_ASM_INST(0x4f94eada) // sdot v26.4s, v22.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4eaba) // sdot v26.4s, v21.16b, v20.4b[3] + bgt label_4 + ldr q19, [x25, #0x0] + ld1r { v18.4s }, [x22] + add x22, x22, #0x4 + scvtf v26.4s, v26.4s + ld1r { v17.4s }, [x22] + ldr q16, [x25, #0x10] + sub x21, x21, #0x1 + add x22, x22, #0x4 + add x25, x25, #0x20 + fmla v27.4s, v19.4s, v18.s[0] + fmul v16.4s, v16.4s, v17.4s + fmla v27.4s, v26.4s, v16.4s + cbnz x21, label_3 + ldr q18, [x25, #0x0] + ld1r { v17.4s }, [x27] + add x20, x27, #0x4 + cmp x24, #0x4 + ld1r { v16.4s }, [x20] + add x25, x25, #0x10 + fadd v27.4s, v27.4s, v18.4s + fmax v27.4s, v27.4s, v17.4s + fmin v27.4s, v27.4s, v16.4s + fcvtn v16.4h, v27.4s + blt label_5 + str d16, [x9, #0x0] + b label_8 +KAI_ASM_LABEL(label_5) // Partial output + mov x20, x9 + tbz x24, #1, label_6 + st1 { v16.s }[0], [x20], #0x4 + tbz x24, #0, label_7 + st1 { v16.h }[2], [x20] + b label_7 +KAI_ASM_LABEL(label_6) // Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] +KAI_ASM_LABEL(label_7) // Output block 0: Done +KAI_ASM_LABEL(label_8) // Stores done + subs x24, x24, #0x4 + add x9, x9, #0x8 + bgt label_2 + subs x26, x26, #0x1 + add x12, x12, x15 + mov x9, x23 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c new file mode 100644 index 00000000..b2580e81 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.c @@ -0,0 +1,178 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: 
Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64) +#error "Dotprod extension required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h" + +#include + +#include "kai/kai_common.h" + +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; + size_t num_subblocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_sum_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_bl = 32; + +inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { + return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs; +} + +inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + size_t num_bytes_per_block_rhs = + (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; + return num_bytes_per_block_rhs; +} + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { + return kai_mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); +} + +inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSUME((bl % kai_bl) == 0); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); + + size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row); + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t 
kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, size_t k, size_t bl) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k, bl); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + KAI_ASSUME(m == 1); + + if (m == 0) { + return; + } + const size_t num_subblocks = bl / kai_bl; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h new file mode 100644 index 00000000..7d702fd3 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h @@ -0,0 +1,150 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon to dynamically quantize and pack the LHS matrix in a single +/// step. +/// -# @ref kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon to pack the RHS NxK matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) +/// values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) +/// values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Symmetric Signed 8-bit with per-block (multiple of 32) quantization (qsi8d32) and packed. +/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-block (multiple of 32) quantization (qai4c32) and packed. 
+/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// It must be a multiple of the block length (bl). +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S new file mode 100644 index 00000000..ff2b29eb --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod_asm.S @@ -0,0 +1,154 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x21, #0x8 + movi v28.16b, #0xf0 + mov x15, #0x20 + ldr x14, [x0, #0x40] + ldr x13, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x12, [x0, #0x8] + ldr x11, [x0, #0x10] + ldr x10, [x0, #0x30] + madd x15, x14, x15, x21 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + mov x26, x20 + mul x15, x13, x15 +KAI_ASM_LABEL(label_1) // Row loop + mov x25, x11 + mov x24, x10 + add x23, x9, x28 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x12 + movi v27.16b, #0x0 + mov x21, x13 +KAI_ASM_LABEL(label_3) // Block loop + movi v26.4s, #0x0 + mov x20, x14 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q25, [x25, #0x0] + ldr q24, [x22, #0x0] + subs x20, x20, #0x1 + ldr q23, [x25, #0x10] + ldr q22, [x25, #0x20] + ldr q21, [x25, #0x30] + ldr q20, [x22, #0x10] + add x25, x25, #0x40 + add x22, x22, #0x20 + shl v19.16b, v25.16b, #0x4 + and v25.16b, v25.16b, v28.16b + shl v18.16b, v23.16b, #0x4 + shl v17.16b, v22.16b, #0x4 + shl v16.16b, v21.16b, #0x4 + and v23.16b, v23.16b, v28.16b + KAI_ASM_INST(0x4f98e27a) // sdot v26.4s, v19.16b, v24.4b[0] + and v22.16b, v22.16b, v28.16b + and v21.16b, v21.16b, v28.16b + KAI_ASM_INST(0x4fb8e25a) // sdot v26.4s, v18.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ea3a) // sdot v26.4s, v17.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8ea1a) // sdot v26.4s, v16.16b, v24.4b[3] + KAI_ASM_INST(0x4f94e33a) // sdot v26.4s, v25.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e2fa) // sdot v26.4s, v23.16b, v20.4b[1] + KAI_ASM_INST(0x4f94eada) // sdot v26.4s, v22.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4eaba) // sdot v26.4s, v21.16b, v20.4b[3] + bgt label_4 + ldr q19, [x25, #0x0] + ld1r { v18.4s }, [x22] + add x22, x22, #0x4 + scvtf v26.4s, v26.4s + ld1r { v17.4s }, [x22] + ldr q16, [x25, #0x10] + sub x21, x21, #0x1 + add x22, x22, #0x4 + add x25, x25, #0x20 + fmla v27.4s, v19.4s, v18.s[0] + fmul v16.4s, v16.4s, v17.4s + fmla v27.4s, v26.4s, v16.4s + cbnz x21, label_3 + ldr q18, [x25, #0x0] + ld1r { v17.4s }, [x27] + add x20, x27, #0x4 + cmp x24, #0x4 + ld1r { v16.4s }, [x20] + add x25, x25, #0x10 + fadd v27.4s, v27.4s, v18.4s + fmax v27.4s, v27.4s, v17.4s + fmin v27.4s, v27.4s, v16.4s + blt label_5 + str q27, [x9, #0x0] + b label_8 +KAI_ASM_LABEL(label_5) // Partial output + mov x20, x9 + tbz x24, #1, label_6 + st1 { v27.d }[0], [x20], #0x8 + tbz x24, #0, label_7 + st1 { v27.s }[2], [x20] + b label_7 +KAI_ASM_LABEL(label_6) // Output block 0: partial_1_0 + st1 { v27.s }[0], [x20] +KAI_ASM_LABEL(label_7) // Output block 0: Done +KAI_ASM_LABEL(label_8) // Stores done + subs x24, x24, #0x4 + add x9, x9, #0x10 + bgt label_2 + subs x26, x26, #0x1 + add x12, x12, x15 + mov x9, x23 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp b/test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp index 72a535c5..1a1f3acb 100644 --- a/test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp +++ b/test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp @@ -14,6 +14,7 @@ #include #include +#include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h" #include 
"kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h" @@ -39,14 +40,16 @@ namespace kai::test { -static const std::array, 3> +static const std::array, 4> variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p = { {{UKERNEL_MATMUL_VARIANT(clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod), "kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}, {UKERNEL_MATMUL_VARIANT(clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm), "kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", cpu_has_i8mm_and_fp16}, {UKERNEL_MATMUL_VARIANT(clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod), - "kai_matmul_clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", cpu_has_dotprod_and_fp16}}}; + "kai_matmul_clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", cpu_has_dotprod_and_fp16}, + {UKERNEL_MATMUL_VARIANT(clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod), + "kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}}}; class MatMulTest_f16_qsi8d32p_qai4c32p : public ::testing::TestWithParam {}; diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp index a36ff48b..fc703432 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp @@ -14,6 +14,7 @@ #include #include +#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h" @@ -37,14 +38,16 @@ namespace kai::test { -static const std::array, 3> +static const std::array, 4> variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p = { {{UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod), "kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod}, {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm), "kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", cpu_has_i8mm}, {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod), - "kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", cpu_has_dotprod}}}; + "kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod", cpu_has_dotprod}, + {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod), + "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod", cpu_has_dotprod}}}; class MatMulTest_f32_qsi8d32p_qai4c32p : public ::testing::TestWithParam {}; -- GitLab