From 1677dcd3883d449f52351a6ac238f2e69b36de44 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 14 Apr 2025 12:29:58 +0100 Subject: [PATCH 1/4] add support for f16 with asymmetric int8 LHS, symmetric int8 RHS Signed-off-by: Evie Wright --- CMakeLists.txt | 9 + ...6_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c | 163 ++++ ...6_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h | 138 ++++ ...i8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_asm.S | 141 ++++ ...6_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c | 163 ++++ ...6_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h | 138 ++++ ...i8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_asm.S | 144 ++++ ..._qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c | 163 ++++ ..._qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h | 138 ++++ ...8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_asm.S | 713 ++++++++++++++++++ ...f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c | 163 ++++ ...f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h | 138 ++++ ...qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_asm.S | 653 ++++++++++++++++ ...tmul_clamp_f16_qai8dxp_qsi8cxp_interface.h | 53 ++ test/reference/matmul.cpp | 9 + .../matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp | 232 ++++++ 16 files changed, 3158 insertions(+) create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_asm.S create mode 100644 
kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h create mode 100644 test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 32fda7ee..a792d06a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,6 +111,12 @@ set(KLEIDIAI_FILES_NEON_FP16_DOTPROD_ASM kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_asm.S + 
kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_FP16_DOTPROD @@ -121,6 +127,8 @@ set(KLEIDIAI_FILES_NEON_FP16_I8MM_ASM kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c ) set(KLEIDIAI_FILES_NEON_FP16_I8MM @@ -378,6 +386,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp test/tests/matmul_clamp_f16_qai8dxp_qsi4cxp_test.cpp + test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp test/tests/matmul_test.cpp ) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c new file mode 100644 index 00000000..775cb7c3 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c @@ -0,0 +1,163 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_DOTPROD) || !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#error "Dotprod extension and fp16 
vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + uint16_t* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +static const size_t kai_num_bytes_qvalue_rhs = 1; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal * kai_num_bytes_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t 
kai_get_nr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + + if (m == 0) { + return; + } + const size_t num_blocks = kai_get_k_roundedup(k) / kai_k_multiple_of; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + 
+ kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h new file mode 100644 index 00000000..7f56d13b --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h @@ -0,0 +1,138 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. 
+/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be 1. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step. 
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. 
+void kai_run_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_asm.S new file mode 100644 index 00000000..f86dcf85 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod_asm.S @@ -0,0 +1,141 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod) + KAI_ASM_ALIGN + + 
KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x13, #0x20 + mov x21, #0x8 + ldr x12, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x11, [x0, #0x8] + ldr x10, [x0, #0x10] + ldr x9, [x0, #0x30] + ldr x28, [x0, #0x0] + ldr x27, [x0, #0x20] + madd x13, x12, x13, x21 + ldr x26, [x0, #0x18] + mov x25, x20 +KAI_ASM_LABEL(label_1) // Row loop + mov x24, x10 + mov x23, x9 + add x22, x28, x27 +KAI_ASM_LABEL(label_2) // Column loop + mov x21, x11 + movi v25.4s, #0x0 + mov x20, x12 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q16, [x24, #0x0] + ldr q24, [x21, #0x0] + subs x20, x20, #0x1 + ldr q23, [x24, #0x10] + ldr q22, [x24, #0x20] + ldr q21, [x24, #0x30] + ldr q20, [x24, #0x40] + ldr q19, [x21, #0x10] + ldr q18, [x24, #0x50] + KAI_ASM_INST(0x4f98e219) // sdot v25.4s, v16.16b, v24.4b[0] + add x21, x21, #0x20 + ldr q17, [x24, #0x60] + ldr q16, [x24, #0x70] + add x24, x24, #0x80 + KAI_ASM_INST(0x4fb8e2f9) // sdot v25.4s, v23.16b, v24.4b[1] + KAI_ASM_INST(0x4f98ead9) // sdot v25.4s, v22.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8eab9) // sdot v25.4s, v21.16b, v24.4b[3] + KAI_ASM_INST(0x4f93e299) // sdot v25.4s, v20.16b, v19.4b[0] + KAI_ASM_INST(0x4fb3e259) // sdot v25.4s, v18.16b, v19.4b[1] + KAI_ASM_INST(0x4f93ea39) // sdot v25.4s, v17.16b, v19.4b[2] + KAI_ASM_INST(0x4fb3ea19) // sdot v25.4s, v16.16b, v19.4b[3] + bgt label_3 + ldr q22, [x24, #0x0] + ld1r { v21.4s }, [x21] + add x21, x21, #0x4 + add x20, x26, #0x4 + ld1r { v20.4s }, [x21] + ldr q16, [x24, #0x10] + cmp x23, #0x4 + ldr q19, [x24, #0x20] + ld1r { v18.4s }, [x26] + add x24, x24, #0x30 + ld1r { v17.4s }, [x20] + mla v25.4s, v22.4s, v21.s[0] + fmul v16.4s, v16.4s, v20.4s + scvtf 
v25.4s, v25.4s + fmul v16.4s, v25.4s, v16.4s + fadd v16.4s, v16.4s, v19.4s + fmax v16.4s, v16.4s, v18.4s + fmin v16.4s, v16.4s, v17.4s + fcvtn v16.4h, v16.4s + blt label_4 + str d16, [x28, #0x0] + b label_7 +KAI_ASM_LABEL(label_4) // Partial output + mov x20, x28 + tbz x23, #1, label_5 + st1 { v16.s }[0], [x20], #0x4 + tbz x23, #0, label_6 + st1 { v16.h }[2], [x20] + b label_6 +KAI_ASM_LABEL(label_5) // Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] +KAI_ASM_LABEL(label_6) // Output block 0: Done +KAI_ASM_LABEL(label_7) // Stores done + subs x23, x23, #0x4 + add x28, x28, #0x8 + bgt label_2 + subs x25, x25, #0x1 + add x11, x11, x13 + mov x28, x22 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c new file mode 100644 index 00000000..75ec292f --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c @@ -0,0 +1,163 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_DOTPROD) || !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. 
+ +#include "kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + uint16_t* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +static const size_t kai_num_bytes_qvalue_rhs = 1; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t 
k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal * kai_num_bytes_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t 
kai_get_dst_size_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + + if (m == 0) { + return; + } + const size_t num_blocks = kai_get_k_roundedup(k) / kai_k_multiple_of; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h new file mode 100644 index 00000000..660a3792 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h @@ -0,0 +1,138 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon to pack the RHS NxK matrix. 
+/// -# @ref kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. +/// @param[in] k Total number of columns in the LHS matrix (not packed). 
+/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be 1. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step. 
+/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_asm.S new file mode 100644 index 00000000..4dad5225 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod_asm.S @@ -0,0 +1,144 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + 
#define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x13, #0x20 + mov x21, #0x8 + ldr x12, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x11, [x0, #0x8] + ldr x10, [x0, #0x10] + ldr x9, [x0, #0x30] + ldr x28, [x0, #0x0] + ldr x27, [x0, #0x20] + madd x13, x12, x13, x21 + ldr x26, [x0, #0x18] + mov x25, x20 +KAI_ASM_LABEL(label_1) // Row loop + mov x24, x10 + mov x23, x9 + add x22, x28, x27 +KAI_ASM_LABEL(label_2) // Column loop + mov x21, x11 + movi v27.4s, #0x0 + movi v26.4s, #0x0 + mov x20, x12 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q25, [x24, #0x0] + ldr q17, [x24, #0x10] + subs x20, x20, #0x1 + ld1r { v16.2d }, [x21], #0x8 + ldr q24, [x24, #0x20] + ldr q23, [x24, #0x30] + ldr q22, [x24, #0x40] + ld1r { v21.2d }, [x21], #0x8 + ldr q20, [x24, #0x50] + ld1r { v19.2d }, [x21], #0x8 + ldr q18, [x24, #0x60] + KAI_ASM_INST(0x4e90973b) // sdot v27.4s, v25.16b, v16.16b + KAI_ASM_INST(0x4e90963a) // sdot v26.4s, v17.16b, v16.16b + ldr q17, [x24, #0x70] + add x24, x24, #0x80 + ld1r { v16.2d }, [x21], #0x8 + KAI_ASM_INST(0x4e95971b) // sdot v27.4s, v24.16b, v21.16b + KAI_ASM_INST(0x4e9596fa) // sdot v26.4s, v23.16b, v21.16b + KAI_ASM_INST(0x4e9396db) // sdot v27.4s, v22.16b, v19.16b + KAI_ASM_INST(0x4e93969a) // sdot v26.4s, v20.16b, v19.16b + KAI_ASM_INST(0x4e90965b) // sdot v27.4s, v18.16b, v16.16b + KAI_ASM_INST(0x4e90963a) // sdot v26.4s, v17.16b, v16.16b + bgt label_3 + ldr q22, [x24, #0x0] + ld1r { v21.4s }, [x21] + addp v27.4s, v27.4s, v26.4s + add x21, x21, #0x4 + ld1r { v20.4s }, [x21] + ldr q16, [x24, #0x10] + add x20, x26, #0x4 + cmp x23, #0x4 + ldr q19, [x24, #0x20] + ld1r { v18.4s }, [x26] + add x24, x24, #0x30 + ld1r { v17.4s }, [x20] + mla v27.4s, v22.4s, v21.s[0] + fmul v16.4s, v16.4s, v20.4s + scvtf v27.4s, v27.4s + fmul v16.4s, v27.4s, v16.4s + fadd v16.4s, v16.4s, v19.4s + fmax v16.4s, v16.4s, v18.4s + fmin v16.4s, v16.4s, v17.4s + fcvtn v16.4h, v16.4s + blt label_4 + str d16, [x28, #0x0] + b label_7 
+KAI_ASM_LABEL(label_4) // Partial output + mov x20, x28 + tbz x23, #1, label_5 + st1 { v16.s }[0], [x20], #0x4 + tbz x23, #0, label_6 + st1 { v16.h }[2], [x20] + b label_6 +KAI_ASM_LABEL(label_5) // Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] +KAI_ASM_LABEL(label_6) // Output block 0: Done +KAI_ASM_LABEL(label_7) // Stores done + subs x23, x23, #0x4 + add x28, x28, #0x8 + bgt label_2 + subs x25, x25, #0x1 + add x11, x11, x13 + mov x28, x22 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c new file mode 100644 index 00000000..f81e7d7b --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c @@ -0,0 +1,163 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_DOTPROD) || !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. 
+ +#include "kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + uint16_t* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( + KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 16; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 4; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +static const size_t kai_num_bytes_qvalue_rhs = 1; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const 
size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal * kai_num_bytes_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t 
kai_get_dst_size_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + + if (m == 0) { + return; + } + const size_t num_blocks = kai_get_k_roundedup(k) / kai_k_multiple_of; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h new file mode 100644 index 00000000..eb533351 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h @@ -0,0 +1,138 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon to pack the RHS NxK matrix. 
+/// -# @ref kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). 
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(
+    size_t m_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) values.
+///
+/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(
+    size_t n_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(
+    size_t m,  //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_asm.S new file mode 100644 index 00000000..ecba5573 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod_asm.S @@ -0,0 +1,713 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define 
KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + ldr x15, [x0, #0x0] + mov x14, x20 + ldr x13, [x0, #0x20] + madd x6, x7, x6, x21 + ldr x12, [x0, #0x18] + cmp x14, #0x10 + blt label_14 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x15, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x23, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 + movi v27.4s, #0x0 + movi v26.4s, #0x0 + add x22, x27, x6 + add x21, x22, x6 + add x20, x21, x6 + movi v25.4s, #0x0 + movi v24.4s, #0x0 + movi v23.4s, #0x0 + movi v22.4s, #0x0 + movi v21.4s, #0x0 + movi v20.4s, #0x0 + movi v19.4s, #0x0 + movi v18.4s, #0x0 + movi v17.4s, #0x0 + movi v16.4s, #0x0 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q15, [x11, #0x0] + ldr q7, [x27, #0x0] + subs x23, x23, #0x1 + ldr q5, [x22, #0x0] + ldr q6, [x21, #0x0] + ldr q4, [x20, #0x0] + ldr q14, [x11, #0x10] + ldr q3, [x27, #0x10] + ldr q2, [x22, #0x10] + KAI_ASM_INST(0x4f87e1ff) // sdot v31.4s, v15.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e1fe) // sdot v30.4s, v15.16b, v7.4b[1] + ldr q1, [x21, #0x10] + ldr q0, [x20, #0x10] + KAI_ASM_INST(0x4f87e9fd) // sdot v29.4s, v15.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e9fc) // sdot v28.4s, v15.16b, v7.4b[3] + ldr q10, [x11, #0x20] + ldr q13, [x27, #0x20] + KAI_ASM_INST(0x4f85e1fb) // sdot v27.4s, v15.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e1fa) // sdot v26.4s, v15.16b, v5.4b[1] + ldr q12, [x22, #0x20] + ldr q11, [x21, #0x20] + KAI_ASM_INST(0x4f85e9f9) // sdot v25.4s, v15.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e9f8) // sdot v24.4s, v15.16b, v5.4b[3] + ldr q9, [x20, #0x20] + ldr q5, [x11, #0x30] + KAI_ASM_INST(0x4f86e1f7) // sdot v23.4s, v15.16b, v6.4b[0] + 
KAI_ASM_INST(0x4fa6e1f6) // sdot v22.4s, v15.16b, v6.4b[1] + ldr q8, [x27, #0x30] + ldr q7, [x22, #0x30] + KAI_ASM_INST(0x4f86e9f5) // sdot v21.4s, v15.16b, v6.4b[2] + KAI_ASM_INST(0x4fa6e9f4) // sdot v20.4s, v15.16b, v6.4b[3] + ldr q6, [x21, #0x30] + KAI_ASM_INST(0x4f84e1f3) // sdot v19.4s, v15.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e1f2) // sdot v18.4s, v15.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e9f1) // sdot v17.4s, v15.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e9f0) // sdot v16.4s, v15.16b, v4.4b[3] + ldr q4, [x20, #0x30] + ldr q15, [x11, #0x40] + KAI_ASM_INST(0x4f83e1df) // sdot v31.4s, v14.16b, v3.4b[0] + KAI_ASM_INST(0x4fa3e1de) // sdot v30.4s, v14.16b, v3.4b[1] + KAI_ASM_INST(0x4f83e9dd) // sdot v29.4s, v14.16b, v3.4b[2] + KAI_ASM_INST(0x4fa3e9dc) // sdot v28.4s, v14.16b, v3.4b[3] + ldr q3, [x27, #0x40] + KAI_ASM_INST(0x4f82e1db) // sdot v27.4s, v14.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e1da) // sdot v26.4s, v14.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e9d9) // sdot v25.4s, v14.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e9d8) // sdot v24.4s, v14.16b, v2.4b[3] + ldr q2, [x22, #0x40] + KAI_ASM_INST(0x4f81e1d7) // sdot v23.4s, v14.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e1d6) // sdot v22.4s, v14.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e9d5) // sdot v21.4s, v14.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e9d4) // sdot v20.4s, v14.16b, v1.4b[3] + ldr q1, [x21, #0x40] + KAI_ASM_INST(0x4f80e1d3) // sdot v19.4s, v14.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e1d2) // sdot v18.4s, v14.16b, v0.4b[1] + KAI_ASM_INST(0x4f80e9d1) // sdot v17.4s, v14.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0e9d0) // sdot v16.4s, v14.16b, v0.4b[3] + ldr q0, [x20, #0x40] + ldr q14, [x11, #0x50] + KAI_ASM_INST(0x4f8de15f) // sdot v31.4s, v10.16b, v13.4b[0] + KAI_ASM_INST(0x4fade15e) // sdot v30.4s, v10.16b, v13.4b[1] + KAI_ASM_INST(0x4f8de95d) // sdot v29.4s, v10.16b, v13.4b[2] + KAI_ASM_INST(0x4fade95c) // sdot v28.4s, v10.16b, v13.4b[3] + ldr q13, [x27, #0x50] + KAI_ASM_INST(0x4f8ce15b) // sdot v27.4s, v10.16b, v12.4b[0] + KAI_ASM_INST(0x4face15a) // 
sdot v26.4s, v10.16b, v12.4b[1] + KAI_ASM_INST(0x4f8ce959) // sdot v25.4s, v10.16b, v12.4b[2] + KAI_ASM_INST(0x4face958) // sdot v24.4s, v10.16b, v12.4b[3] + ldr q12, [x22, #0x50] + KAI_ASM_INST(0x4f8be157) // sdot v23.4s, v10.16b, v11.4b[0] + KAI_ASM_INST(0x4fabe156) // sdot v22.4s, v10.16b, v11.4b[1] + KAI_ASM_INST(0x4f8be955) // sdot v21.4s, v10.16b, v11.4b[2] + KAI_ASM_INST(0x4fabe954) // sdot v20.4s, v10.16b, v11.4b[3] + ldr q11, [x21, #0x50] + KAI_ASM_INST(0x4f89e153) // sdot v19.4s, v10.16b, v9.4b[0] + KAI_ASM_INST(0x4fa9e152) // sdot v18.4s, v10.16b, v9.4b[1] + KAI_ASM_INST(0x4f89e951) // sdot v17.4s, v10.16b, v9.4b[2] + KAI_ASM_INST(0x4fa9e950) // sdot v16.4s, v10.16b, v9.4b[3] + ldr q10, [x20, #0x50] + ldr q9, [x11, #0x60] + KAI_ASM_INST(0x4f88e0bf) // sdot v31.4s, v5.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e0be) // sdot v30.4s, v5.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e8bd) // sdot v29.4s, v5.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e8bc) // sdot v28.4s, v5.16b, v8.4b[3] + ldr q8, [x27, #0x60] + KAI_ASM_INST(0x4f87e0bb) // sdot v27.4s, v5.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e0ba) // sdot v26.4s, v5.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e8b9) // sdot v25.4s, v5.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e8b8) // sdot v24.4s, v5.16b, v7.4b[3] + ldr q7, [x22, #0x60] + KAI_ASM_INST(0x4f86e0b7) // sdot v23.4s, v5.16b, v6.4b[0] + KAI_ASM_INST(0x4fa6e0b6) // sdot v22.4s, v5.16b, v6.4b[1] + KAI_ASM_INST(0x4f86e8b5) // sdot v21.4s, v5.16b, v6.4b[2] + KAI_ASM_INST(0x4fa6e8b4) // sdot v20.4s, v5.16b, v6.4b[3] + ldr q6, [x21, #0x60] + KAI_ASM_INST(0x4f84e0b3) // sdot v19.4s, v5.16b, v4.4b[0] + KAI_ASM_INST(0x4fa4e0b2) // sdot v18.4s, v5.16b, v4.4b[1] + KAI_ASM_INST(0x4f84e8b1) // sdot v17.4s, v5.16b, v4.4b[2] + KAI_ASM_INST(0x4fa4e8b0) // sdot v16.4s, v5.16b, v4.4b[3] + ldr q5, [x20, #0x60] + ldr q4, [x11, #0x70] + KAI_ASM_INST(0x4f83e1ff) // sdot v31.4s, v15.16b, v3.4b[0] + KAI_ASM_INST(0x4fa3e1fe) // sdot v30.4s, v15.16b, v3.4b[1] + add x11, x11, #0x80 + KAI_ASM_INST(0x4f83e9fd) // sdot 
v29.4s, v15.16b, v3.4b[2] + KAI_ASM_INST(0x4fa3e9fc) // sdot v28.4s, v15.16b, v3.4b[3] + ldr q3, [x27, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4f82e1fb) // sdot v27.4s, v15.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e1fa) // sdot v26.4s, v15.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e9f9) // sdot v25.4s, v15.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e9f8) // sdot v24.4s, v15.16b, v2.4b[3] + ldr q2, [x22, #0x70] + add x22, x22, #0x80 + KAI_ASM_INST(0x4f81e1f7) // sdot v23.4s, v15.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e1f6) // sdot v22.4s, v15.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e9f5) // sdot v21.4s, v15.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e9f4) // sdot v20.4s, v15.16b, v1.4b[3] + ldr q1, [x21, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4f80e1f3) // sdot v19.4s, v15.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e1f2) // sdot v18.4s, v15.16b, v0.4b[1] + KAI_ASM_INST(0x4f80e9f1) // sdot v17.4s, v15.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0e9f0) // sdot v16.4s, v15.16b, v0.4b[3] + ldr q0, [x20, #0x70] + add x20, x20, #0x80 + KAI_ASM_INST(0x4f8de1df) // sdot v31.4s, v14.16b, v13.4b[0] + KAI_ASM_INST(0x4fade1de) // sdot v30.4s, v14.16b, v13.4b[1] + KAI_ASM_INST(0x4f8de9dd) // sdot v29.4s, v14.16b, v13.4b[2] + KAI_ASM_INST(0x4fade9dc) // sdot v28.4s, v14.16b, v13.4b[3] + KAI_ASM_INST(0x4f8ce1db) // sdot v27.4s, v14.16b, v12.4b[0] + KAI_ASM_INST(0x4face1da) // sdot v26.4s, v14.16b, v12.4b[1] + KAI_ASM_INST(0x4f8ce9d9) // sdot v25.4s, v14.16b, v12.4b[2] + KAI_ASM_INST(0x4face9d8) // sdot v24.4s, v14.16b, v12.4b[3] + KAI_ASM_INST(0x4f8be1d7) // sdot v23.4s, v14.16b, v11.4b[0] + KAI_ASM_INST(0x4fabe1d6) // sdot v22.4s, v14.16b, v11.4b[1] + KAI_ASM_INST(0x4f8be9d5) // sdot v21.4s, v14.16b, v11.4b[2] + KAI_ASM_INST(0x4fabe9d4) // sdot v20.4s, v14.16b, v11.4b[3] + KAI_ASM_INST(0x4f8ae1d3) // sdot v19.4s, v14.16b, v10.4b[0] + KAI_ASM_INST(0x4faae1d2) // sdot v18.4s, v14.16b, v10.4b[1] + KAI_ASM_INST(0x4f8ae9d1) // sdot v17.4s, v14.16b, v10.4b[2] + KAI_ASM_INST(0x4faae9d0) // sdot v16.4s, v14.16b, v10.4b[3] + 
KAI_ASM_INST(0x4f88e13f) // sdot v31.4s, v9.16b, v8.4b[0] + KAI_ASM_INST(0x4fa8e13e) // sdot v30.4s, v9.16b, v8.4b[1] + KAI_ASM_INST(0x4f88e93d) // sdot v29.4s, v9.16b, v8.4b[2] + KAI_ASM_INST(0x4fa8e93c) // sdot v28.4s, v9.16b, v8.4b[3] + KAI_ASM_INST(0x4f87e13b) // sdot v27.4s, v9.16b, v7.4b[0] + KAI_ASM_INST(0x4fa7e13a) // sdot v26.4s, v9.16b, v7.4b[1] + KAI_ASM_INST(0x4f87e939) // sdot v25.4s, v9.16b, v7.4b[2] + KAI_ASM_INST(0x4fa7e938) // sdot v24.4s, v9.16b, v7.4b[3] + KAI_ASM_INST(0x4f86e137) // sdot v23.4s, v9.16b, v6.4b[0] + KAI_ASM_INST(0x4fa6e136) // sdot v22.4s, v9.16b, v6.4b[1] + KAI_ASM_INST(0x4f86e935) // sdot v21.4s, v9.16b, v6.4b[2] + KAI_ASM_INST(0x4fa6e934) // sdot v20.4s, v9.16b, v6.4b[3] + KAI_ASM_INST(0x4f85e133) // sdot v19.4s, v9.16b, v5.4b[0] + KAI_ASM_INST(0x4fa5e132) // sdot v18.4s, v9.16b, v5.4b[1] + KAI_ASM_INST(0x4f85e931) // sdot v17.4s, v9.16b, v5.4b[2] + KAI_ASM_INST(0x4fa5e930) // sdot v16.4s, v9.16b, v5.4b[3] + KAI_ASM_INST(0x4f83e09f) // sdot v31.4s, v4.16b, v3.4b[0] + KAI_ASM_INST(0x4fa3e09e) // sdot v30.4s, v4.16b, v3.4b[1] + KAI_ASM_INST(0x4f83e89d) // sdot v29.4s, v4.16b, v3.4b[2] + KAI_ASM_INST(0x4fa3e89c) // sdot v28.4s, v4.16b, v3.4b[3] + KAI_ASM_INST(0x4f82e09b) // sdot v27.4s, v4.16b, v2.4b[0] + KAI_ASM_INST(0x4fa2e09a) // sdot v26.4s, v4.16b, v2.4b[1] + KAI_ASM_INST(0x4f82e899) // sdot v25.4s, v4.16b, v2.4b[2] + KAI_ASM_INST(0x4fa2e898) // sdot v24.4s, v4.16b, v2.4b[3] + KAI_ASM_INST(0x4f81e097) // sdot v23.4s, v4.16b, v1.4b[0] + KAI_ASM_INST(0x4fa1e096) // sdot v22.4s, v4.16b, v1.4b[1] + KAI_ASM_INST(0x4f81e895) // sdot v21.4s, v4.16b, v1.4b[2] + KAI_ASM_INST(0x4fa1e894) // sdot v20.4s, v4.16b, v1.4b[3] + KAI_ASM_INST(0x4f80e093) // sdot v19.4s, v4.16b, v0.4b[0] + KAI_ASM_INST(0x4fa0e092) // sdot v18.4s, v4.16b, v0.4b[1] + KAI_ASM_INST(0x4f80e891) // sdot v17.4s, v4.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0e890) // sdot v16.4s, v4.16b, v0.4b[3] + bgt label_3 + ldr q5, [x11, #0x0] + ld1 { v1.4s }, [x27] + add x27, x27, #0x10 
+ ldr q4, [x11, #0x10] + ldr q0, [x27, #0x0] + add x11, x11, #0x20 + mla v31.4s, v5.4s, v1.s[0] + mla v30.4s, v5.4s, v1.s[1] + mla v29.4s, v5.4s, v1.s[2] + mla v28.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v31.4s, v31.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v30.4s, v30.4s + scvtf v29.4s, v29.4s + scvtf v28.4s, v28.4s + fmul v31.4s, v31.4s, v3.4s + fmul v30.4s, v30.4s, v2.4s + fmul v29.4s, v29.4s, v1.4s + fmul v28.4s, v28.4s, v0.4s + ld1 { v1.4s }, [x22] + add x22, x22, #0x10 + ldr q0, [x22, #0x0] + mla v27.4s, v5.4s, v1.s[0] + mla v26.4s, v5.4s, v1.s[1] + mla v25.4s, v5.4s, v1.s[2] + mla v24.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v27.4s, v27.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v26.4s, v26.4s + scvtf v25.4s, v25.4s + scvtf v24.4s, v24.4s + fmul v27.4s, v27.4s, v3.4s + fmul v26.4s, v26.4s, v2.4s + fmul v25.4s, v25.4s, v1.4s + fmul v24.4s, v24.4s, v0.4s + ld1 { v1.4s }, [x21] + add x21, x21, #0x10 + ldr q0, [x21, #0x0] + mla v23.4s, v5.4s, v1.s[0] + mla v22.4s, v5.4s, v1.s[1] + mla v21.4s, v5.4s, v1.s[2] + mla v20.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v23.4s, v23.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v22.4s, v22.4s + scvtf v21.4s, v21.4s + scvtf v20.4s, v20.4s + fmul v23.4s, v23.4s, v3.4s + fmul v22.4s, v22.4s, v2.4s + fmul v21.4s, v21.4s, v1.4s + fmul v20.4s, v20.4s, v0.4s + ld1 { v1.4s }, [x20] + add x20, x20, #0x10 + ldr q0, [x20, #0x0] + mla v19.4s, v5.4s, v1.s[0] + mla v18.4s, v5.4s, v1.s[1] + mla v17.4s, v5.4s, v1.s[2] + mla v16.4s, v5.4s, v1.s[3] + fmul v3.4s, v4.4s, v0.s[0] + fmul v2.4s, v4.4s, v0.s[1] + fmul v1.4s, v4.4s, v0.s[2] + scvtf v19.4s, v19.4s + fmul v0.4s, v4.4s, v0.s[3] + scvtf v18.4s, v18.4s + scvtf v17.4s, v17.4s + scvtf v16.4s, v16.4s + fmul v19.4s, v19.4s, v3.4s + fmul v18.4s, v18.4s, v2.4s + fmul v17.4s, v17.4s, v1.4s 
+ fmul v16.4s, v16.4s, v0.4s + ldr q2, [x11, #0x0] + ld1r { v1.4s }, [x12] + add x20, x12, #0x4 + cmp x10, #0x4 + ld1r { v0.4s }, [x20] + add x11, x11, #0x10 + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin v16.4s, v16.4s, v0.4s + fcvtn v31.4h, v31.4s + fcvtn v30.4h, v30.4s + fcvtn v29.4h, v29.4s + fcvtn v28.4h, v28.4s + fcvtn v27.4h, v27.4s + fcvtn v26.4h, v26.4s + fcvtn v25.4h, v25.4s + fcvtn v24.4h, v24.4s + fcvtn v23.4h, v23.4s + fcvtn v22.4h, v22.4s + fcvtn v21.4h, v21.4s + fcvtn v20.4h, v20.4s + fcvtn v19.4h, v19.4s + fcvtn v18.4h, v18.4s + fcvtn v17.4h, v17.4s + fcvtn v16.4h, v16.4s + blt label_8 + mov x20, x15 + str d31, [x20, #0x0] + add x20, x20, x13 + str 
d30, [x20, #0x0] + add x20, x20, x13 + str d29, [x20, #0x0] + add x20, x20, x13 + str d28, [x20, #0x0] + add x20, x20, x13 + str d27, [x20, #0x0] + add x20, x20, x13 + str d26, [x20, #0x0] + add x20, x20, x13 + str d25, [x20, #0x0] + add x20, x20, x13 + str d24, [x20, #0x0] + add x20, x20, x13 + str d23, [x20, #0x0] + add x20, x20, x13 + str d22, [x20, #0x0] + add x20, x20, x13 + str d21, [x20, #0x0] + add x20, x20, x13 + str d20, [x20, #0x0] + add x20, x20, x13 + str d19, [x20, #0x0] + add x20, x20, x13 + str d18, [x20, #0x0] + add x20, x20, x13 + str d17, [x20, #0x0] + add x20, x20, x13 + str d16, [x20, #0x0] + b label_13 +KAI_ASM_LABEL(label_8) // Partial output + mov x28, x15 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_9 + st1 { v24.s }[0], [x23], #0x4 + st1 { v25.s }[0], [x25], #0x4 + st1 { v26.s }[0], [x24], #0x4 + st1 { v27.s }[0], [x26], #0x4 + st1 { v28.s }[0], [x20], #0x4 + st1 { v29.s }[0], [x22], #0x4 + st1 { v30.s }[0], [x21], #0x4 + st1 { v31.s }[0], [x28], #0x4 + tbz x10, #0, label_10 + st1 { v24.h }[2], [x23] + st1 { v25.h }[2], [x25] + st1 { v26.h }[2], [x24] + st1 { v27.h }[2], [x26] + st1 { v28.h }[2], [x20] + st1 { v29.h }[2], [x22] + st1 { v30.h }[2], [x21] + st1 { v31.h }[2], [x28] + b label_10 +KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 + st1 { v24.h }[0], [x23] + st1 { v25.h }[0], [x25] + st1 { v26.h }[0], [x24] + st1 { v27.h }[0], [x26] + st1 { v28.h }[0], [x20] + st1 { v29.h }[0], [x22] + st1 { v30.h }[0], [x21] + st1 { v31.h }[0], [x28] +KAI_ASM_LABEL(label_10) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_11 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x24], #0x4 + st1 { v18.s }[0], [x21], #0x4 + st1 
{ v19.s }[0], [x26], #0x4 + st1 { v20.s }[0], [x22], #0x4 + st1 { v21.s }[0], [x25], #0x4 + st1 { v22.s }[0], [x23], #0x4 + st1 { v23.s }[0], [x27], #0x4 + tbz x10, #0, label_12 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x24] + st1 { v18.h }[2], [x21] + st1 { v19.h }[2], [x26] + st1 { v20.h }[2], [x22] + st1 { v21.h }[2], [x25] + st1 { v22.h }[2], [x23] + st1 { v23.h }[2], [x27] + b label_12 +KAI_ASM_LABEL(label_11) // Output block 1: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x24] + st1 { v18.h }[0], [x21] + st1 { v19.h }[0], [x26] + st1 { v20.h }[0], [x22] + st1 { v21.h }[0], [x25] + st1 { v22.h }[0], [x23] + st1 { v23.h }[0], [x27] +KAI_ASM_LABEL(label_12) // Output block 1: Done +KAI_ASM_LABEL(label_13) // Output stage exit + subs x10, x10, #0x4 + add x15, x15, #0x8 + bgt label_2 + mov x20, #0x4 + sub x14, x14, #0x10 + cmp x14, #0x10 + mov x15, x9 + madd x8, x20, x6, x8 + bge label_1 +KAI_ASM_LABEL(label_14) // Row loop skip + cbz x14, label_23 +KAI_ASM_LABEL(label_15) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x15, x13, LSL #2 +KAI_ASM_LABEL(label_16) // Row tail: Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x20, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 +KAI_ASM_LABEL(label_17) // Row tail: Sub block loop + ldr q17, [x26, #0x0] + ldr q16, [x27, #0x0] + subs x20, x20, #0x1 + ldr q1, [x26, #0x10] + ldr q0, [x27, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x27, #0x20] + ldr q25, [x26, #0x30] + ldr q24, [x27, #0x30] + KAI_ASM_INST(0x4f90e23f) // sdot v31.4s, v17.16b, v16.4b[0] + KAI_ASM_INST(0x4fb0e23e) // sdot v30.4s, v17.16b, v16.4b[1] + ldr q23, [x26, #0x40] + ldr q22, [x27, #0x40] + KAI_ASM_INST(0x4f90ea3d) // sdot v29.4s, v17.16b, v16.4b[2] + KAI_ASM_INST(0x4fb0ea3c) // sdot v28.4s, v17.16b, v16.4b[3] + ldr q21, [x26, #0x50] + ldr q20, [x27, #0x50] + ldr q19, [x26, #0x60] + ldr q18, [x27, #0x60] + ldr q17, [x26, #0x70] + ldr q16, [x27, #0x70] + KAI_ASM_INST(0x4f80e03f) // sdot v31.4s, v1.16b, 
v0.4b[0] + KAI_ASM_INST(0x4fa0e03e) // sdot v30.4s, v1.16b, v0.4b[1] + KAI_ASM_INST(0x4f80e83d) // sdot v29.4s, v1.16b, v0.4b[2] + KAI_ASM_INST(0x4fa0e83c) // sdot v28.4s, v1.16b, v0.4b[3] + add x27, x27, #0x80 + add x26, x26, #0x80 + KAI_ASM_INST(0x4f9ae37f) // sdot v31.4s, v27.16b, v26.4b[0] + KAI_ASM_INST(0x4fbae37e) // sdot v30.4s, v27.16b, v26.4b[1] + KAI_ASM_INST(0x4f9aeb7d) // sdot v29.4s, v27.16b, v26.4b[2] + KAI_ASM_INST(0x4fbaeb7c) // sdot v28.4s, v27.16b, v26.4b[3] + KAI_ASM_INST(0x4f98e33f) // sdot v31.4s, v25.16b, v24.4b[0] + KAI_ASM_INST(0x4fb8e33e) // sdot v30.4s, v25.16b, v24.4b[1] + KAI_ASM_INST(0x4f98eb3d) // sdot v29.4s, v25.16b, v24.4b[2] + KAI_ASM_INST(0x4fb8eb3c) // sdot v28.4s, v25.16b, v24.4b[3] + KAI_ASM_INST(0x4f96e2ff) // sdot v31.4s, v23.16b, v22.4b[0] + KAI_ASM_INST(0x4fb6e2fe) // sdot v30.4s, v23.16b, v22.4b[1] + KAI_ASM_INST(0x4f96eafd) // sdot v29.4s, v23.16b, v22.4b[2] + KAI_ASM_INST(0x4fb6eafc) // sdot v28.4s, v23.16b, v22.4b[3] + KAI_ASM_INST(0x4f94e2bf) // sdot v31.4s, v21.16b, v20.4b[0] + KAI_ASM_INST(0x4fb4e2be) // sdot v30.4s, v21.16b, v20.4b[1] + KAI_ASM_INST(0x4f94eabd) // sdot v29.4s, v21.16b, v20.4b[2] + KAI_ASM_INST(0x4fb4eabc) // sdot v28.4s, v21.16b, v20.4b[3] + KAI_ASM_INST(0x4f92e27f) // sdot v31.4s, v19.16b, v18.4b[0] + KAI_ASM_INST(0x4fb2e27e) // sdot v30.4s, v19.16b, v18.4b[1] + KAI_ASM_INST(0x4f92ea7d) // sdot v29.4s, v19.16b, v18.4b[2] + KAI_ASM_INST(0x4fb2ea7c) // sdot v28.4s, v19.16b, v18.4b[3] + KAI_ASM_INST(0x4f90e23f) // sdot v31.4s, v17.16b, v16.4b[0] + KAI_ASM_INST(0x4fb0e23e) // sdot v30.4s, v17.16b, v16.4b[1] + KAI_ASM_INST(0x4f90ea3d) // sdot v29.4s, v17.16b, v16.4b[2] + KAI_ASM_INST(0x4fb0ea3c) // sdot v28.4s, v17.16b, v16.4b[3] + bgt label_17 + ldr q18, [x26, #0x0] + ld1 { v17.4s }, [x27] + add x27, x27, #0x10 + ldr q20, [x26, #0x10] + ldr q16, [x27, #0x0] + add x26, x26, #0x20 + mla v31.4s, v18.4s, v17.s[0] + mla v30.4s, v18.4s, v17.s[1] + mla v29.4s, v18.4s, v17.s[2] + mla v28.4s, v18.4s, v17.s[3] + 
fmul v19.4s, v20.4s, v16.s[0] + fmul v18.4s, v20.4s, v16.s[1] + fmul v17.4s, v20.4s, v16.s[2] + scvtf v31.4s, v31.4s + fmul v16.4s, v20.4s, v16.s[3] + scvtf v30.4s, v30.4s + scvtf v29.4s, v29.4s + scvtf v28.4s, v28.4s + fmul v31.4s, v31.4s, v19.4s + fmul v30.4s, v30.4s, v18.4s + fmul v29.4s, v29.4s, v17.4s + fmul v28.4s, v28.4s, v16.4s + ldr q18, [x26, #0x0] + ld1r { v17.4s }, [x12] + add x20, x12, #0x4 + cmp x25, #0x4 + ld1r { v16.4s }, [x20] + add x26, x26, #0x10 + fadd v31.4s, v31.4s, v18.4s + fadd v30.4s, v30.4s, v18.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fmax v30.4s, v30.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + fcvtn v19.4h, v31.4s + fcvtn v18.4h, v30.4s + fcvtn v17.4h, v29.4s + fcvtn v16.4h, v28.4s + blt label_19 + mov x20, x15 + cmp x14, #0x1 + str d19, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x2 + str d18, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x3 + str d17, [x20, #0x0] + add x20, x20, x13 + ble label_22 + str d16, [x20, #0x0] + b label_22 +KAI_ASM_LABEL(label_19) // Row tail: Partial output + mov x23, x15 + cmp x14, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_20 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x21], #0x4 + st1 { v18.s }[0], [x22], #0x4 + st1 { v19.s }[0], [x23], #0x4 + tbz x25, #0, label_21 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x21] + st1 { v18.h }[2], [x22] + st1 { v19.h }[2], [x23] + b label_21 +KAI_ASM_LABEL(label_20) // Row tail: Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x21] + st1 { v18.h }[0], [x22] + st1 { v19.h }[0], [x23] +KAI_ASM_LABEL(label_21) // Row tail: Output block 0: Done 
+KAI_ASM_LABEL(label_22) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x15, x15, #0x8 + bgt label_16 + subs x14, x14, #0x4 + add x8, x8, x6 + mov x15, x24 + bgt label_15 +KAI_ASM_LABEL(label_23) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c new file mode 100644 index 00000000..49d6ba4e --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c @@ -0,0 +1,163 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_MATMUL_INT8) || !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#error "I8mm extension and fp16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. 
+ +#include "kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" + +#include <stddef.h> +#include <stdint.h> + +#include "kai/kai_common.h" + +typedef struct { + uint16_t* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 16; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 1; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +static const size_t kai_num_bytes_qvalue_rhs = 1; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t
k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal * kai_num_bytes_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t 
kai_get_dst_size_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + + if (m == 0) { + return; + } + const size_t num_blocks = kai_get_k_roundedup(k) / kai_k_multiple_of; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h new file mode 100644 index 00000000..31d3a8c2 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h @@ -0,0 +1,138 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_f16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). 
+/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 8-bit with per-channel quantization (qsi8cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step.
+/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_asm.S new file mode 100644 index 00000000..c1a22c62 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm_asm.S @@ -0,0 +1,653 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define 
KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm) + stp x20, x21, [sp, -144]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x6, #0x80 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + ldr x15, [x0, #0x0] + mov x14, x20 + ldr x13, [x0, #0x20] + madd x6, x7, x6, x21 + ldr x12, [x0, #0x18] + cmp x14, #0x10 + blt label_14 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x15, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x23, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 + movi v27.4s, #0x0 + movi v26.4s, #0x0 + add x22, x27, x6 + add x21, x22, x6 + add x20, x21, x6 + movi v25.4s, #0x0 + movi v24.4s, #0x0 + movi v23.4s, #0x0 + movi v22.4s, #0x0 + movi v21.4s, #0x0 + movi v20.4s, #0x0 + movi v19.4s, #0x0 + movi v18.4s, #0x0 + movi v17.4s, #0x0 + movi v16.4s, #0x0 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q2, [x11, #0x0] + ldr q1, [x11, #0x10] + subs x23, x23, #0x1 + ldr q5, [x27, #0x0] + ldr q9, [x27, #0x10] + ldr q8, [x22, #0x0] + ldr q7, [x22, #0x10] + ldr q4, [x21, #0x0] + ldr q14, [x21, #0x10] + ldr q3, [x20, #0x0] + ldr q0, [x20, #0x10] + KAI_ASM_INST(0x4e82a4bf) // smmla v31.4s, v5.16b, v2.16b + KAI_ASM_INST(0x4e81a4be) // smmla v30.4s, v5.16b, v1.16b + ldr q6, [x11, #0x20] + ldr q5, [x11, #0x30] + KAI_ASM_INST(0x4e82a53d) // smmla v29.4s, v9.16b, v2.16b + KAI_ASM_INST(0x4e81a53c) // smmla v28.4s, v9.16b, v1.16b + ldr q13, [x27, #0x20] + ldr q12, [x27, #0x30] + KAI_ASM_INST(0x4e82a51b) // smmla v27.4s, v8.16b, v2.16b + KAI_ASM_INST(0x4e81a51a) // smmla v26.4s, v8.16b, v1.16b + ldr q11, [x22, #0x20] + ldr q10, [x22, #0x30] + KAI_ASM_INST(0x4e82a4f9) // smmla v25.4s, v7.16b, v2.16b + KAI_ASM_INST(0x4e81a4f8) // smmla v24.4s, v7.16b, v1.16b + ldr q9, [x21, #0x20] + ldr q8, [x21, #0x30] + KAI_ASM_INST(0x4e82a497) // 
smmla v23.4s, v4.16b, v2.16b + KAI_ASM_INST(0x4e81a496) // smmla v22.4s, v4.16b, v1.16b + ldr q7, [x20, #0x20] + ldr q4, [x20, #0x30] + KAI_ASM_INST(0x4e82a5d5) // smmla v21.4s, v14.16b, v2.16b + KAI_ASM_INST(0x4e81a5d4) // smmla v20.4s, v14.16b, v1.16b + ldr q15, [x11, #0x40] + ldr q14, [x11, #0x50] + KAI_ASM_INST(0x4e82a473) // smmla v19.4s, v3.16b, v2.16b + KAI_ASM_INST(0x4e81a472) // smmla v18.4s, v3.16b, v1.16b + ldr q3, [x27, #0x40] + KAI_ASM_INST(0x4e82a411) // smmla v17.4s, v0.16b, v2.16b + ldr q2, [x27, #0x50] + KAI_ASM_INST(0x4e81a410) // smmla v16.4s, v0.16b, v1.16b + ldr q1, [x22, #0x40] + ldr q0, [x22, #0x50] + KAI_ASM_INST(0x4e86a5bf) // smmla v31.4s, v13.16b, v6.16b + KAI_ASM_INST(0x4e85a5be) // smmla v30.4s, v13.16b, v5.16b + ldr q13, [x21, #0x40] + KAI_ASM_INST(0x4e86a59d) // smmla v29.4s, v12.16b, v6.16b + KAI_ASM_INST(0x4e85a59c) // smmla v28.4s, v12.16b, v5.16b + ldr q12, [x21, #0x50] + KAI_ASM_INST(0x4e86a57b) // smmla v27.4s, v11.16b, v6.16b + KAI_ASM_INST(0x4e85a57a) // smmla v26.4s, v11.16b, v5.16b + ldr q11, [x20, #0x40] + KAI_ASM_INST(0x4e86a559) // smmla v25.4s, v10.16b, v6.16b + KAI_ASM_INST(0x4e85a558) // smmla v24.4s, v10.16b, v5.16b + ldr q10, [x20, #0x50] + KAI_ASM_INST(0x4e86a537) // smmla v23.4s, v9.16b, v6.16b + KAI_ASM_INST(0x4e85a536) // smmla v22.4s, v9.16b, v5.16b + ldr q9, [x11, #0x60] + KAI_ASM_INST(0x4e86a515) // smmla v21.4s, v8.16b, v6.16b + KAI_ASM_INST(0x4e85a514) // smmla v20.4s, v8.16b, v5.16b + ldr q8, [x11, #0x70] + add x11, x11, #0x80 + KAI_ASM_INST(0x4e86a4f3) // smmla v19.4s, v7.16b, v6.16b + KAI_ASM_INST(0x4e85a4f2) // smmla v18.4s, v7.16b, v5.16b + ldr q7, [x27, #0x60] + KAI_ASM_INST(0x4e86a491) // smmla v17.4s, v4.16b, v6.16b + ldr q6, [x27, #0x70] + KAI_ASM_INST(0x4e85a490) // smmla v16.4s, v4.16b, v5.16b + ldr q5, [x22, #0x60] + ldr q4, [x22, #0x70] + KAI_ASM_INST(0x4e8fa47f) // smmla v31.4s, v3.16b, v15.16b + KAI_ASM_INST(0x4e8ea47e) // smmla v30.4s, v3.16b, v14.16b + ldr q3, [x21, #0x60] + 
KAI_ASM_INST(0x4e8fa45d) // smmla v29.4s, v2.16b, v15.16b + KAI_ASM_INST(0x4e8ea45c) // smmla v28.4s, v2.16b, v14.16b + ldr q2, [x21, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4e8fa43b) // smmla v27.4s, v1.16b, v15.16b + KAI_ASM_INST(0x4e8ea43a) // smmla v26.4s, v1.16b, v14.16b + ldr q1, [x20, #0x60] + add x22, x22, #0x80 + KAI_ASM_INST(0x4e8fa419) // smmla v25.4s, v0.16b, v15.16b + KAI_ASM_INST(0x4e8ea418) // smmla v24.4s, v0.16b, v14.16b + ldr q0, [x20, #0x70] + add x21, x21, #0x80 + KAI_ASM_INST(0x4e8fa5b7) // smmla v23.4s, v13.16b, v15.16b + KAI_ASM_INST(0x4e8ea5b6) // smmla v22.4s, v13.16b, v14.16b + add x20, x20, #0x80 + KAI_ASM_INST(0x4e8fa595) // smmla v21.4s, v12.16b, v15.16b + KAI_ASM_INST(0x4e8ea594) // smmla v20.4s, v12.16b, v14.16b + KAI_ASM_INST(0x4e8fa573) // smmla v19.4s, v11.16b, v15.16b + KAI_ASM_INST(0x4e8ea572) // smmla v18.4s, v11.16b, v14.16b + KAI_ASM_INST(0x4e8fa551) // smmla v17.4s, v10.16b, v15.16b + KAI_ASM_INST(0x4e8ea550) // smmla v16.4s, v10.16b, v14.16b + KAI_ASM_INST(0x4e89a4ff) // smmla v31.4s, v7.16b, v9.16b + KAI_ASM_INST(0x4e88a4fe) // smmla v30.4s, v7.16b, v8.16b + KAI_ASM_INST(0x4e89a4dd) // smmla v29.4s, v6.16b, v9.16b + KAI_ASM_INST(0x4e88a4dc) // smmla v28.4s, v6.16b, v8.16b + KAI_ASM_INST(0x4e89a4bb) // smmla v27.4s, v5.16b, v9.16b + KAI_ASM_INST(0x4e88a4ba) // smmla v26.4s, v5.16b, v8.16b + KAI_ASM_INST(0x4e89a499) // smmla v25.4s, v4.16b, v9.16b + KAI_ASM_INST(0x4e88a498) // smmla v24.4s, v4.16b, v8.16b + KAI_ASM_INST(0x4e89a477) // smmla v23.4s, v3.16b, v9.16b + KAI_ASM_INST(0x4e88a476) // smmla v22.4s, v3.16b, v8.16b + KAI_ASM_INST(0x4e89a455) // smmla v21.4s, v2.16b, v9.16b + KAI_ASM_INST(0x4e88a454) // smmla v20.4s, v2.16b, v8.16b + KAI_ASM_INST(0x4e89a433) // smmla v19.4s, v1.16b, v9.16b + KAI_ASM_INST(0x4e88a432) // smmla v18.4s, v1.16b, v8.16b + KAI_ASM_INST(0x4e89a411) // smmla v17.4s, v0.16b, v9.16b + KAI_ASM_INST(0x4e88a410) // smmla v16.4s, v0.16b, v8.16b + bgt label_3 + ldr q7, [x11, #0x0] + ld1 { v4.4s }, 
[x27] + uzp1 v3.2d, v31.2d, v30.2d + uzp2 v2.2d, v31.2d, v30.2d + ldr q6, [x11, #0x10] + uzp1 v1.2d, v29.2d, v28.2d + uzp2 v0.2d, v29.2d, v28.2d + add x27, x27, #0x10 + ldr q28, [x27, #0x0] + add x11, x11, #0x20 + mla v3.4s, v7.4s, v4.s[0] + mla v2.4s, v7.4s, v4.s[1] + mla v1.4s, v7.4s, v4.s[2] + mla v0.4s, v7.4s, v4.s[3] + fmul v31.4s, v6.4s, v28.s[0] + fmul v30.4s, v6.4s, v28.s[1] + fmul v29.4s, v6.4s, v28.s[2] + fmul v28.4s, v6.4s, v28.s[3] + scvtf v3.4s, v3.4s + scvtf v2.4s, v2.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v31.4s, v3.4s, v31.4s + fmul v30.4s, v2.4s, v30.4s + fmul v29.4s, v1.4s, v29.4s + fmul v28.4s, v0.4s, v28.4s + ld1 { v5.4s }, [x22] + uzp1 v4.2d, v27.2d, v26.2d + uzp2 v3.2d, v27.2d, v26.2d + add x22, x22, #0x10 + ldr q2, [x22, #0x0] + uzp1 v1.2d, v25.2d, v24.2d + uzp2 v0.2d, v25.2d, v24.2d + mla v4.4s, v7.4s, v5.s[0] + mla v3.4s, v7.4s, v5.s[1] + mla v1.4s, v7.4s, v5.s[2] + mla v0.4s, v7.4s, v5.s[3] + fmul v27.4s, v6.4s, v2.s[0] + fmul v26.4s, v6.4s, v2.s[1] + fmul v25.4s, v6.4s, v2.s[2] + scvtf v4.4s, v4.4s + fmul v24.4s, v6.4s, v2.s[3] + scvtf v3.4s, v3.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v27.4s, v4.4s, v27.4s + fmul v26.4s, v3.4s, v26.4s + fmul v25.4s, v1.4s, v25.4s + fmul v24.4s, v0.4s, v24.4s + ld1 { v5.4s }, [x21] + uzp1 v4.2d, v23.2d, v22.2d + uzp2 v3.2d, v23.2d, v22.2d + add x21, x21, #0x10 + ldr q2, [x21, #0x0] + uzp1 v1.2d, v21.2d, v20.2d + uzp2 v0.2d, v21.2d, v20.2d + mla v4.4s, v7.4s, v5.s[0] + mla v3.4s, v7.4s, v5.s[1] + mla v1.4s, v7.4s, v5.s[2] + mla v0.4s, v7.4s, v5.s[3] + fmul v23.4s, v6.4s, v2.s[0] + fmul v22.4s, v6.4s, v2.s[1] + fmul v21.4s, v6.4s, v2.s[2] + scvtf v4.4s, v4.4s + fmul v20.4s, v6.4s, v2.s[3] + scvtf v3.4s, v3.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v23.4s, v4.4s, v23.4s + fmul v22.4s, v3.4s, v22.4s + fmul v21.4s, v1.4s, v21.4s + fmul v20.4s, v0.4s, v20.4s + ld1 { v5.4s }, [x20] + uzp1 v4.2d, v19.2d, v18.2d + uzp2 v3.2d, v19.2d, v18.2d + add x20, x20, #0x10 + ldr q2, [x20, 
#0x0] + uzp1 v1.2d, v17.2d, v16.2d + uzp2 v0.2d, v17.2d, v16.2d + mla v4.4s, v7.4s, v5.s[0] + mla v3.4s, v7.4s, v5.s[1] + mla v1.4s, v7.4s, v5.s[2] + mla v0.4s, v7.4s, v5.s[3] + fmul v19.4s, v6.4s, v2.s[0] + fmul v18.4s, v6.4s, v2.s[1] + fmul v17.4s, v6.4s, v2.s[2] + scvtf v4.4s, v4.4s + fmul v16.4s, v6.4s, v2.s[3] + scvtf v3.4s, v3.4s + scvtf v1.4s, v1.4s + scvtf v0.4s, v0.4s + fmul v19.4s, v4.4s, v19.4s + fmul v18.4s, v3.4s, v18.4s + fmul v17.4s, v1.4s, v17.4s + fmul v16.4s, v0.4s, v16.4s + ldr q2, [x11, #0x0] + ld1r { v1.4s }, [x12] + add x20, x12, #0x4 + cmp x10, #0x4 + ld1r { v0.4s }, [x20] + add x11, x11, #0x10 + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin 
v16.4s, v16.4s, v0.4s + fcvtn v31.4h, v31.4s + fcvtn v30.4h, v30.4s + fcvtn v29.4h, v29.4s + fcvtn v28.4h, v28.4s + fcvtn v27.4h, v27.4s + fcvtn v26.4h, v26.4s + fcvtn v25.4h, v25.4s + fcvtn v24.4h, v24.4s + fcvtn v23.4h, v23.4s + fcvtn v22.4h, v22.4s + fcvtn v21.4h, v21.4s + fcvtn v20.4h, v20.4s + fcvtn v19.4h, v19.4s + fcvtn v18.4h, v18.4s + fcvtn v17.4h, v17.4s + fcvtn v16.4h, v16.4s + blt label_8 + mov x20, x15 + str d31, [x20, #0x0] + add x20, x20, x13 + str d30, [x20, #0x0] + add x20, x20, x13 + str d29, [x20, #0x0] + add x20, x20, x13 + str d28, [x20, #0x0] + add x20, x20, x13 + str d27, [x20, #0x0] + add x20, x20, x13 + str d26, [x20, #0x0] + add x20, x20, x13 + str d25, [x20, #0x0] + add x20, x20, x13 + str d24, [x20, #0x0] + add x20, x20, x13 + str d23, [x20, #0x0] + add x20, x20, x13 + str d22, [x20, #0x0] + add x20, x20, x13 + str d21, [x20, #0x0] + add x20, x20, x13 + str d20, [x20, #0x0] + add x20, x20, x13 + str d19, [x20, #0x0] + add x20, x20, x13 + str d18, [x20, #0x0] + add x20, x20, x13 + str d17, [x20, #0x0] + add x20, x20, x13 + str d16, [x20, #0x0] + b label_13 +KAI_ASM_LABEL(label_8) // Partial output + mov x28, x15 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_9 + st1 { v24.s }[0], [x23], #0x4 + st1 { v25.s }[0], [x25], #0x4 + st1 { v26.s }[0], [x24], #0x4 + st1 { v27.s }[0], [x26], #0x4 + st1 { v28.s }[0], [x20], #0x4 + st1 { v29.s }[0], [x22], #0x4 + st1 { v30.s }[0], [x21], #0x4 + st1 { v31.s }[0], [x28], #0x4 + tbz x10, #0, label_10 + st1 { v24.h }[2], [x23] + st1 { v25.h }[2], [x25] + st1 { v26.h }[2], [x24] + st1 { v27.h }[2], [x26] + st1 { v28.h }[2], [x20] + st1 { v29.h }[2], [x22] + st1 { v30.h }[2], [x21] + st1 { v31.h }[2], [x28] + b label_10 +KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 + st1 { v24.h }[0], [x23] + st1 { v25.h }[0], [x25] + st1 { v26.h }[0], [x24] + 
st1 { v27.h }[0], [x26] + st1 { v28.h }[0], [x20] + st1 { v29.h }[0], [x22] + st1 { v30.h }[0], [x21] + st1 { v31.h }[0], [x28] +KAI_ASM_LABEL(label_10) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_11 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x24], #0x4 + st1 { v18.s }[0], [x21], #0x4 + st1 { v19.s }[0], [x26], #0x4 + st1 { v20.s }[0], [x22], #0x4 + st1 { v21.s }[0], [x25], #0x4 + st1 { v22.s }[0], [x23], #0x4 + st1 { v23.s }[0], [x27], #0x4 + tbz x10, #0, label_12 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x24] + st1 { v18.h }[2], [x21] + st1 { v19.h }[2], [x26] + st1 { v20.h }[2], [x22] + st1 { v21.h }[2], [x25] + st1 { v22.h }[2], [x23] + st1 { v23.h }[2], [x27] + b label_12 +KAI_ASM_LABEL(label_11) // Output block 1: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x24] + st1 { v18.h }[0], [x21] + st1 { v19.h }[0], [x26] + st1 { v20.h }[0], [x22] + st1 { v21.h }[0], [x25] + st1 { v22.h }[0], [x23] + st1 { v23.h }[0], [x27] +KAI_ASM_LABEL(label_12) // Output block 1: Done +KAI_ASM_LABEL(label_13) // Output stage exit + subs x10, x10, #0x4 + add x15, x15, #0x8 + bgt label_2 + mov x20, #0x4 + sub x14, x14, #0x10 + cmp x14, #0x10 + mov x15, x9 + madd x8, x20, x6, x8 + bge label_1 +KAI_ASM_LABEL(label_14) // Row loop skip + cbz x14, label_23 +KAI_ASM_LABEL(label_15) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x15, x13, LSL #2 +KAI_ASM_LABEL(label_16) // Row tail: Column loop + mov x27, x8 + movi v31.4s, #0x0 + movi v30.4s, #0x0 + mov x20, x7 + movi v29.4s, #0x0 + movi v28.4s, #0x0 +KAI_ASM_LABEL(label_17) // Row tail: Sub block loop + ldr q19, [x26, #0x0] + ldr q18, [x26, #0x10] + subs x20, x20, #0x1 + ldr q17, [x27, #0x0] + ldr q16, [x27, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x26, #0x30] + ldr q25, [x27, #0x20] + ldr q24, [x27, #0x30] + ldr q23, [x26, 
#0x40] + ldr q22, [x26, #0x50] + KAI_ASM_INST(0x4e93a63f) // smmla v31.4s, v17.16b, v19.16b + KAI_ASM_INST(0x4e92a63e) // smmla v30.4s, v17.16b, v18.16b + ldr q21, [x27, #0x40] + ldr q20, [x27, #0x50] + KAI_ASM_INST(0x4e93a61d) // smmla v29.4s, v16.16b, v19.16b + KAI_ASM_INST(0x4e92a61c) // smmla v28.4s, v16.16b, v18.16b + ldr q19, [x26, #0x60] + ldr q18, [x26, #0x70] + add x26, x26, #0x80 + ldr q17, [x27, #0x60] + ldr q16, [x27, #0x70] + add x27, x27, #0x80 + KAI_ASM_INST(0x4e9ba73f) // smmla v31.4s, v25.16b, v27.16b + KAI_ASM_INST(0x4e9aa73e) // smmla v30.4s, v25.16b, v26.16b + KAI_ASM_INST(0x4e9ba71d) // smmla v29.4s, v24.16b, v27.16b + KAI_ASM_INST(0x4e9aa71c) // smmla v28.4s, v24.16b, v26.16b + KAI_ASM_INST(0x4e97a6bf) // smmla v31.4s, v21.16b, v23.16b + KAI_ASM_INST(0x4e96a6be) // smmla v30.4s, v21.16b, v22.16b + KAI_ASM_INST(0x4e97a69d) // smmla v29.4s, v20.16b, v23.16b + KAI_ASM_INST(0x4e96a69c) // smmla v28.4s, v20.16b, v22.16b + KAI_ASM_INST(0x4e93a63f) // smmla v31.4s, v17.16b, v19.16b + KAI_ASM_INST(0x4e92a63e) // smmla v30.4s, v17.16b, v18.16b + KAI_ASM_INST(0x4e93a61d) // smmla v29.4s, v16.16b, v19.16b + KAI_ASM_INST(0x4e92a61c) // smmla v28.4s, v16.16b, v18.16b + bgt label_17 + ldr q18, [x26, #0x0] + ld1 { v17.4s }, [x27] + uzp1 v24.2d, v31.2d, v30.2d + uzp2 v23.2d, v31.2d, v30.2d + ldr q22, [x26, #0x10] + uzp1 v21.2d, v29.2d, v28.2d + uzp2 v20.2d, v29.2d, v28.2d + add x27, x27, #0x10 + ldr q16, [x27, #0x0] + add x26, x26, #0x20 + mla v24.4s, v18.4s, v17.s[0] + mla v23.4s, v18.4s, v17.s[1] + mla v21.4s, v18.4s, v17.s[2] + mla v20.4s, v18.4s, v17.s[3] + fmul v19.4s, v22.4s, v16.s[0] + fmul v18.4s, v22.4s, v16.s[1] + fmul v17.4s, v22.4s, v16.s[2] + fmul v16.4s, v22.4s, v16.s[3] + scvtf v24.4s, v24.4s + scvtf v23.4s, v23.4s + scvtf v21.4s, v21.4s + scvtf v20.4s, v20.4s + fmul v31.4s, v24.4s, v19.4s + fmul v30.4s, v23.4s, v18.4s + fmul v29.4s, v21.4s, v17.4s + fmul v28.4s, v20.4s, v16.4s + ldr q18, [x26, #0x0] + ld1r { v17.4s }, [x12] + add x20, x12, 
#0x4 + cmp x25, #0x4 + ld1r { v16.4s }, [x20] + add x26, x26, #0x10 + fadd v31.4s, v31.4s, v18.4s + fadd v30.4s, v30.4s, v18.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fmax v30.4s, v30.4s, v17.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmin v30.4s, v30.4s, v16.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + fcvtn v19.4h, v31.4s + fcvtn v18.4h, v30.4s + fcvtn v17.4h, v29.4s + fcvtn v16.4h, v28.4s + blt label_19 + mov x20, x15 + cmp x14, #0x1 + str d19, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x2 + str d18, [x20, #0x0] + add x20, x20, x13 + ble label_22 + cmp x14, #0x3 + str d17, [x20, #0x0] + add x20, x20, x13 + ble label_22 + str d16, [x20, #0x0] + b label_22 +KAI_ASM_LABEL(label_19) // Row tail: Partial output + mov x23, x15 + cmp x14, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x14, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x14, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_20 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x21], #0x4 + st1 { v18.s }[0], [x22], #0x4 + st1 { v19.s }[0], [x23], #0x4 + tbz x25, #0, label_21 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x21] + st1 { v18.h }[2], [x22] + st1 { v19.h }[2], [x23] + b label_21 +KAI_ASM_LABEL(label_20) // Row tail: Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x21] + st1 { v18.h }[0], [x22] + st1 { v19.h }[0], [x23] +KAI_ASM_LABEL(label_21) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_22) // Row tail: Output stage exit + subs x25, x25, #0x4 + add x15, x15, #0x8 + bgt label_16 + subs x14, x14, #0x4 + add x8, x8, x6 + mov x15, x24 + bgt label_15 +KAI_ASM_LABEL(label_23) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + 
ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h new file mode 100644 index 00000000..8bfe31a6 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h @@ -0,0 +1,53 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_f16_qai8dxp_qsi8cxp + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void 
(*kai_matmul_clamp_f16_qai8dxp_qsi8cxp_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f16_qai8dxp_qsi8cxp_ukernel { + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_m_step_func_t get_m_step; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_n_step_func_t get_n_step; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_mr_func_t get_mr; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_nr_func_t get_nr; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_kr_func_t get_kr; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_sr_func_t get_sr; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f16_qai8dxp_qsi8cxp_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 9bfefbd2..4ee2046a 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -343,6 +343,15 @@ matmul_nt_t_quantized +matmul_nt_t_quantized( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + template std::vector indirect_matmul_nt_t_quantized( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // diff --git a/test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp 
b/test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp new file mode 100644 index 00000000..c993a8b4 --- /dev/null +++ b/test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp @@ -0,0 +1,232 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/data_format.hpp" +#include "test/common/float16.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/test_suite.hpp" +#include "test/reference/cast.hpp" +#include "test/reference/clamp.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/pad.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +static auto cpu_has_dotprod_and_fp16 = []() { return cpu_has_dotprod() && cpu_has_fp16(); }; +static auto cpu_has_i8mm_and_fp16 = []() { 
return cpu_has_i8mm() && cpu_has_fp16(); }; + +static const std::array, 4> + variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp = {{ + {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod), + "kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}, + {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod), + "kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod_and_fp16}, + {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod), + "kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}, + {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm), + "kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm_and_fp16}, + }}; + +class MatMulTest_f16_qai8dxp_qsi8cxp : public ::testing::TestWithParam {}; + +TEST_P(MatMulTest_f16_qai8dxp_qsi8cxp, EndToEnd) { + const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.at(variant_index); + + if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + const std::uint32_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + if (mr == 1 && M > 1) { + GTEST_SKIP() << "Kernel does not support M != 1"; + } + + auto m_step = ukernel_variant.interface.get_m_step(); + ASSERT_TRUE(m_step % mr == 0); + + auto n_step = ukernel_variant.interface.get_n_step(); + ASSERT_TRUE(n_step % nr == 0); + + const auto rect = portion.compute_portion(M, N, m_step, n_step); + if (rect.height() == 0 || rect.width() 
== 0) { + GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; + } + + // Generates input data. + const auto ref_lhs_f16 = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + std::vector ref_biases; + + if (has_bias) { + ref_biases = fill_random(N, seed + 2); + } + // For the reference implementation, cast the FP16 input to FP32 and the FP32 output back to FP16 because the matmul + // implementation works with FP32 accumulation and casts the result to FP16 + const auto ref_lhs = cast(ref_lhs_f16.data(), ref_lhs_f16.size() * 8 / size_in_bits); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit asymmetric quantization. + // * Quantizes the RHS matrix using 8-bit symmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi8, ref_rhs_scales] = + quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); + + const auto ref_dst_no_clamp = + matmul_nt_t_quantized( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, + ref_rhs_qsi8.data(), ref_rhs_scales.data(), nullptr, 1, K, has_bias ? ref_biases.data() : nullptr, nullptr, + nullptr, 1); + + // Clamps the reference output. + const auto clamp_ratio = 0.8F; + const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_no_clamp.data(), M * N, clamp_ratio); + const auto ref_dst_float = clamp(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max); + + // Cast the reference output to F16 + auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits); + + // Runs the LHS packing micro-kernel. 
+ const auto lhs_start_row = rect.start_row(); + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(M, K, mr, kr, sr); + std::vector imp_packed_lhs(imp_packed_lhs_size); + + auto lhs_stride = K * sizeof(uint16_t); + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, K, mr, kr, sr); + auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); + + ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); + + kai_run_lhs_quant_pack_qai8dxp_f16_neon( + rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_f16.data() + lhs_offset, lhs_stride, + imp_packed_lhs.data() + lhs_packed_offset); + + const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); + std::vector imp_packed_rhs(imp_packed_rhs_size); + const auto rhs_start_row = rect.start_col(); + auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(rhs_start_row, K, nr, kr, sr); + auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); + ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); + + // Runs the RHS packing micro-kernel. + const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; + kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( + 1, N, K, nr, kr, sr, reinterpret_cast(ref_rhs_qsi8.data()), + has_bias ? 
reinterpret_cast(ref_biases.data()) : nullptr, + reinterpret_cast(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, ¶ms); + + const auto dst_stride_row = N * sizeof(uint16_t); + const auto dst_stride_col = sizeof(uint16_t); + const auto dst_offset = + ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); + const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; + ASSERT_EQ(dst_offset, ref_dst_offset); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + std::vector imp_dst(imp_dst_size); + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, + imp_packed_rhs.data() + rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col, + clamp_min, clamp_max); + + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. + DefaultMismatchHandler handler(0, 0.02, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::FP16); + const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); +} +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest_f16_qai8dxp_qsi8cxp, + testing::Combine( + testing::Range(0, variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.size()), + testing::Values( + MatMulShape{1, 2, 32}, // + MatMulShape{1, 3, 32}, // + MatMulShape{1, 4, 32}, // + MatMulShape{1, 5, 31}, // + MatMulShape{3, 3, 32}, // + MatMulShape{4, 4, 32}, // + MatMulShape{5, 5, 31}, // + MatMulShape{16, 32, 64}, // + MatMulShape{16, 32, 36}, // + MatMulShape{15, 35, 65}, // + MatMulShape{8, 32, 64}, // + MatMulShape{15, 31, 45}, // + MatMulShape{1, 35, 65}, // + MatMulShape{1, 128, 32}, // + MatMulShape{64, 128, 32}, // + MatMulShape{77, 99, 64}), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. 
+ MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. + MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. + MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle + MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. + MatrixPortion(0.75, 0, 1, 1), // Partial rows + MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle + ), + testing::Bool()), + [](const auto& info) { + const auto variant_idx = std::get<0>(info.param); + const std::string name{variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.at(variant_idx).name}; + const auto shape = std::get(info.param); + const auto portion = std::get<2>(info.param); + const auto has_bias = std::get<3>(info.param); + + std::stringstream sstream; + sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // + << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // + << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // + << "__PortionHeight_" << static_cast(portion.height() * 1000) // + << "__PortionWidth_" << static_cast(portion.width() * 1000) // + << (has_bias ? 
"__Bias" : ""); + return sstream.str(); + }); + +} // namespace kai::test -- GitLab From ff4a6f9135d7b45043e589868c203849aeda17e2 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 14 Apr 2025 15:29:37 +0100 Subject: [PATCH 2/4] update bazel file to include new ukernels Signed-off-by: Evie Wright --- kai/ukernels/matmul/BUILD.bazel | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 0532f434..f3765e55 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -76,12 +76,16 @@ FP16_DOTPROD_KERNELS_ASM = [ "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod", "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod", "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod", + "matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", + "matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", + "matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", "matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", ] # buildifier: keep sorted FP16_I8MM_KERNELS_ASM = [ "matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm", + "matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", "matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", ] -- GitLab From bf14e364c99bed186341f64704d4c02430f375bf Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 14 Apr 2025 16:46:43 +0100 Subject: [PATCH 3/4] pre-commit cleanup Signed-off-by: Evie Wright --- ...6_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c | 25 ++++++----- ...6_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h | 42 ++++++++++--------- 
...6_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c | 25 ++++++----- ...6_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h | 42 ++++++++++--------- ..._qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c | 25 ++++++----- ..._qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h | 42 ++++++++++--------- ...f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c | 25 ++++++----- ...f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h | 42 ++++++++++--------- ...tmul_clamp_f16_qai8dxp_qsi8cxp_interface.h | 2 +- .../matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp | 2 +- 10 files changed, 138 insertions(+), 134 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c index 775cb7c3..4b5c099e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c @@ -25,8 +25,7 @@ typedef struct { size_t num_blocks; } KernelArgs; -void kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( - KernelArgs* args_ptr); +void kai_kernel_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 1; @@ -40,7 +39,8 @@ static const size_t kai_sr = 1; static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; static const size_t kai_num_bytes_zp_lhs = 4; -// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) static const size_t kai_num_bytes_qvalue_rhs = 1; static const size_t kai_num_bytes_multiplier_rhs = 4; static const size_t kai_num_bytes_rsum_rhs = 4; @@ -127,17 +127,16 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( } void kai_run_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( - size_t m, // - size_t n, // - size_t k, // - const void* restrict lhs_packed, // - const void* restrict rhs_packed, // - void* restrict dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // float scalar_max) { - KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); if (m == 0) { diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h index 7f56d13b..73f7bcc5 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h @@ -65,8 +65,8 @@ size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void); /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( - size_t m_idx, // - size_t k); // + size_t m_idx, // + size_t k); // /// Gets the offset in bytes for the packed RHS matrix, /// which contains the packed Quantized Symmetric Signed 
8-bit with per-channel quantization (qsi8cx) values. @@ -76,8 +76,8 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon /// /// @return the offset in bytes to the packed RHS matrix size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( - size_t n_idx, // - size_t k); // + size_t n_idx, // + size_t k); // /// Gets the offset in bytes for the DST matrix /// @@ -87,9 +87,9 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon /// /// @return the DST offset in bytes size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( - size_t m_idx, // - size_t n_idx, // - size_t dst_stride); // + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // /// Gets the size in bytes for the destination (DST) matrix. /// @@ -114,25 +114,27 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( /// @param[in] m The number of output rows written. It must be 1. /// @param[in] n The number of output columns written. /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. -/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. -/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. 
/// @param[in] scalar_min Min value used to clamp the final result. /// @param[in] scalar_max Max value used to clamp the final result. void kai_run_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( - size_t m, // - size_t n, // - size_t k, // - const void* lhs_packed, // - const void* rhs_packed, // - void* dst, // - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // - float scalar_max); // + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c index 75ec292f..a4c57930 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c @@ -25,8 +25,7 @@ typedef struct { size_t num_blocks; } KernelArgs; -void kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( - KernelArgs* args_ptr); +void kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 1; @@ -40,7 +39,8 @@ static const size_t kai_sr = 1; static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; static const size_t kai_num_bytes_zp_lhs = 4; -// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) static const size_t kai_num_bytes_qvalue_rhs = 1; static const size_t kai_num_bytes_multiplier_rhs = 4; static const size_t kai_num_bytes_rsum_rhs = 4; @@ -127,17 +127,16 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( } void kai_run_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( - size_t m, // - size_t n, // - size_t k, // - const void* restrict lhs_packed, // - const void* restrict rhs_packed, // - void* restrict dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // float scalar_max) { - KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); if (m == 0) { diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h index 660a3792..57c3e4a8 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h @@ -65,8 +65,8 @@ size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod(void); /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( - size_t m_idx, // - size_t k); // + size_t m_idx, // + size_t k); // /// Gets the offset in bytes for the packed RHS matrix, /// which contains the packed Quantized Symmetric Signed 
8-bit with per-channel quantization (qsi8cx) values. @@ -76,8 +76,8 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon /// /// @return the offset in bytes to the packed RHS matrix size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( - size_t n_idx, // - size_t k); // + size_t n_idx, // + size_t k); // /// Gets the offset in bytes for the DST matrix /// @@ -87,9 +87,9 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon /// /// @return the DST offset in bytes size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( - size_t m_idx, // - size_t n_idx, // - size_t dst_stride); // + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // /// Gets the size in bytes for the destination (DST) matrix. /// @@ -114,25 +114,27 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( /// @param[in] m The number of output rows written. It must be 1. /// @param[in] n The number of output columns written. /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. -/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. -/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. 
/// @param[in] scalar_min Min value used to clamp the final result. /// @param[in] scalar_max Max value used to clamp the final result. void kai_run_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( - size_t m, // - size_t n, // - size_t k, // - const void* lhs_packed, // - const void* rhs_packed, // - void* dst, // - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // - float scalar_max); // + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c index f81e7d7b..bdb1c3dd 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c @@ -25,8 +25,7 @@ typedef struct { size_t num_blocks; } KernelArgs; -void kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( - KernelArgs* args_ptr); +void kai_kernel_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 16; @@ -40,7 +39,8 @@ static const size_t kai_sr = 1; static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; static const size_t kai_num_bytes_zp_lhs = 4; -// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) static const size_t kai_num_bytes_qvalue_rhs = 1; static const size_t kai_num_bytes_multiplier_rhs = 4; static const size_t kai_num_bytes_rsum_rhs = 4; @@ -127,17 +127,16 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod } void kai_run_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( - size_t m, // - size_t n, // - size_t k, // - const void* restrict lhs_packed, // - const void* restrict rhs_packed, // - void* restrict dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // float scalar_max) { - KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); if (m == 0) { diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h index eb533351..f3a0887f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h @@ -65,8 +65,8 @@ size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod(void) /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( - size_t m_idx, // - size_t k); // + size_t m_idx, // + size_t k); // /// Gets the offset in bytes for the packed RHS matrix, /// which contains the packed Quantized Symmetric 
Signed 8-bit with per-channel quantization (qsi8cx) values. @@ -76,8 +76,8 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neo /// /// @return the offset in bytes to the packed RHS matrix size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( - size_t n_idx, // - size_t k); // + size_t n_idx, // + size_t k); // /// Gets the offset in bytes for the DST matrix /// @@ -87,9 +87,9 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neo /// /// @return the DST offset in bytes size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( - size_t m_idx, // - size_t n_idx, // - size_t dst_stride); // + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // /// Gets the size in bytes for the destination (DST) matrix. /// @@ -114,25 +114,27 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod /// @param[in] m The number of output rows written. /// @param[in] n The number of output columns written. /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. -/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. -/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. 
/// @param[in] scalar_min Min value used to clamp the final result. /// @param[in] scalar_max Max value used to clamp the final result. void kai_run_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( - size_t m, // - size_t n, // - size_t k, // - const void* lhs_packed, // - const void* rhs_packed, // - void* dst, // - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // - float scalar_max); // + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c index 49d6ba4e..2fe1feaf 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c @@ -25,8 +25,7 @@ typedef struct { size_t num_blocks; } KernelArgs; -void kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( - KernelArgs* args_ptr); +void kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(KernelArgs* args_ptr); // Compute args static const size_t kai_m_step = 16; @@ -40,7 +39,8 @@ static const size_t kai_sr = 1; static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; static const size_t kai_num_bytes_zp_lhs = 4; -// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is asymmetric)) +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) static const size_t kai_num_bytes_qvalue_rhs = 1; static const size_t kai_num_bytes_multiplier_rhs = 4; static const size_t kai_num_bytes_rsum_rhs = 4; @@ -127,17 +127,16 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(si } void kai_run_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( - size_t m, // - size_t n, // - size_t k, // - const void* restrict lhs_packed, // - const void* restrict rhs_packed, // - void* restrict dst, // NOLINT(readability-non-const-parameter) - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // float scalar_max) { - KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); if (m == 0) { diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h index 31d3a8c2..329a1360 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h @@ -65,8 +65,8 @@ size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void); /// /// @return the offset in bytes to the packed LHS matrix size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( - size_t m_idx, // - size_t k); // + size_t m_idx, // + size_t k); // /// Gets the offset in bytes for the packed RHS matrix, /// which contains the packed Quantized Symmetric Signed 8-bit with 
per-channel quantization (qsi8cx) values. @@ -76,8 +76,8 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neo /// /// @return the offset in bytes to the packed RHS matrix size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( - size_t n_idx, // - size_t k); // + size_t n_idx, // + size_t k); // /// Gets the offset in bytes for the DST matrix /// @@ -87,9 +87,9 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neo /// /// @return the DST offset in bytes size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( - size_t m_idx, // - size_t n_idx, // - size_t dst_stride); // + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // /// Gets the size in bytes for the destination (DST) matrix. /// @@ -114,25 +114,27 @@ size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( /// @param[in] m The number of output rows written. /// @param[in] n The number of output columns written. /// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. -/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the top of this file. -/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the top of this file. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. /// @param[in] scalar_min Min value used to clamp the final result. 
/// @param[in] scalar_max Max value used to clamp the final result. void kai_run_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( - size_t m, // - size_t n, // - size_t k, // - const void* lhs_packed, // - const void* rhs_packed, // - void* dst, // - size_t dst_stride_row, // - size_t dst_stride_col, // - float scalar_min, // - float scalar_max); // + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h index 8bfe31a6..f6ecd193 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h @@ -46,7 +46,7 @@ struct kai_matmul_clamp_f16_qai8dxp_qsi8cxp_ukernel { kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_dst_offset_func_t get_dst_offset; kai_matmul_clamp_f16_qai8dxp_qsi8cxp_get_dst_size_func_t get_dst_size; kai_matmul_clamp_f16_qai8dxp_qsi8cxp_run_matmul_func_t run_matmul; -}; +}; #ifdef __cplusplus } diff --git a/test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp index c993a8b4..043c217e 100644 --- a/test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp @@ -150,7 +150,7 @@ TEST_P(MatMulTest_f16_qai8dxp_qsi8cxp, EndToEnd) { auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(rhs_start_row, K, nr, kr, sr); auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); 
ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); - + // Runs the RHS packirng micro-kernel. const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( -- GitLab From eb19a677bf317211829f38a20312ff11ae2fc85b Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 15 Apr 2025 16:04:02 +0100 Subject: [PATCH 4/4] update changelog Signed-off-by: Evie Wright --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5d8dd7a..ebbd0acc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- New Advanced SIMD micro-kernels: + - Matrix multiplication (MxN) micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_I8MM and FEAT_DotProd. + - Matrix multiplication (1xN) micro-kernels of QAI8DX LHS and QSI8CX RHS with F16 output, optimized for FEAT_DotProd. + ## v1.7.0 - New SME micro-kernels: -- GitLab