From 508bb0ee3c8113f7ef148dbb362680f0dfc73957 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Fri, 11 Jul 2025 13:07:08 +0100 Subject: [PATCH 1/5] Matmul Micro-kernels BF16 <- (QAI8DXP) LHS x (QSI4C32P) RHS Signed-off-by: Evie Wright --- CHANGELOG.md | 2 + CMakeLists.txt | 6 + kai/ukernels/matmul/BUILD.bazel | 2 + ..._qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.c | 185 +++++ ..._qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h | 146 ++++ ...8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod_asm.S | 159 ++++ ...16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.c | 185 +++++ ...16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h | 146 ++++ ...ai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm_asm.S | 764 ++++++++++++++++++ ...ul_clamp_bf16_qai8dxp_qsi4c32p_interface.h | 53 ++ test/reference/matmul.cpp | 8 + ...atmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp | 448 ++++++++++ 12 files changed, 2104 insertions(+) create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.c create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.c create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h create mode 100644 test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c8316ed..6ed596fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - New Advanced SIMD micro-kernels: - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_I8MM. - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_DotProd. + - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4C32 RHS with BF16 output, optimized for FEAT_I8MM. + - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4C32 RHS with BF16 output, optimized for FEAT_DotProd. - New SME micro-kernels: - Matrix multiplication (1xN) of F32 LHS and RHS with F32 output, using instructions compatible with FEAT_SME. - Matrix multiplication (1xN) of F16 LHS and RHS with F16 output, using instructions compatible with FEAT_SME. 
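Note on usage: the new micro-kernels follow the same call sequence as the existing BF16 qai8dxp/qsi4cxp kernels: pack the LHS and RHS with the packing micro-kernels listed in the kernel headers, query the destination size, then invoke the run entry point. A minimal C sketch for the MxN i8mm variant added by this patch; the buffer handling and clamp range here are illustrative and not part of the patch:

    #include <float.h>
    #include <stdint.h>
    #include <stdlib.h>

    #include "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h"

    // Computes dst = clamp(lhs * rhs) on packed operands. Assumes lhs_packed and
    // rhs_packed were produced with kai_lhs_quant_pack_qai8dxp_bf16_neon and
    // kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0; k must be a multiple of bl, and bl
    // a multiple of 32. Error handling is omitted.
    static void example_matmul(size_t m, size_t n, size_t k, size_t bl,
                               const void* lhs_packed, const void* rhs_packed) {
        // The BF16 destination is stored as 16-bit values.
        uint16_t* dst = malloc(kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(m, n));

        kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
            m, n, k, bl, lhs_packed, rhs_packed, dst,
            n * sizeof(uint16_t),  // dst_stride_row
            sizeof(uint16_t),      // dst_stride_col: must be sizeof(uint16_t)
            -FLT_MAX, FLT_MAX);    // clamp effectively disabled for this sketch

        free(dst);
    }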
diff --git a/CMakeLists.txt b/CMakeLists.txt index d08faf2e..f1209371 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -197,6 +197,8 @@ set(KLEIDIAI_FILES_NEON_DOTPROD_ASM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_DOTPROD @@ -227,6 +229,8 @@ set(KLEIDIAI_FILES_NEON_I8MM_ASM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM @@ -461,6 +465,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp + test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp ) else() add_executable(kleidiai_test @@ -469,6 +474,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/float16_test.cpp test/tests/imatmul_test.cpp test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp + test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index ee19a706..498dbd76 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -111,6 +111,7 @@ DOTPROD_KERNELS = [ # buildifier: keep sorted DOTPROD_KERNELS_ASM = [ + "matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", "matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod", @@ -137,6 +138,7 @@ I8MM_KERNELS = [ # buildifier: keep sorted I8MM_KERNELS_ASM = [ + "matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", "matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.c new 
file mode 100644
index 00000000..785b5e46
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.c
@@ -0,0 +1,185 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64)
+#error "Dotprod extension required to compile this micro-kernel"
+#else // Architectural features check.
+
+#include "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+typedef struct {
+    void* dst;
+    const void* lhs_packed;
+    const void* rhs_packed;
+    const float* clamp_vals;
+    size_t dst_stride_row;
+    size_t m;
+    size_t n;
+    size_t num_blocks;
+    size_t num_subblocks;
+} KernelArgs;
+
+void kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(KernelArgs* args_ptr);
+
+// Compute args
+static const size_t kai_m_step = 1;
+static const size_t kai_n_step = 4;
+// Packing args
+static const size_t kai_mr = 1;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
+static const size_t kai_num_bytes_qvalue_lhs = 1;
+static const size_t kai_num_bytes_multiplier_lhs = 4;
+static const size_t kai_num_bytes_zp_lhs = 4;
+// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
+// asymmetric))
+static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
+static const size_t kai_num_bytes_multiplier_rhs = 2;
+static const size_t kai_num_bytes_rsum_rhs = 4;
+// DST format args
+static const size_t kai_num_bytes_dst_value = 2;
+// Extra args
+static const size_t kai_num_bytes_bias = 4;
+static const size_t kai_k_multiple_of = 32;
+static const size_t kai_bl = 32;
+
+inline static size_t kai_get_k_roundedup(size_t k) {
+    return kai_roundup(k, kai_k_multiple_of);
+}
+
+inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+    size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
+    return num_bytes_per_block_rhs;
+}
+
+inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    return kai_roundup(k, bl) / bl;
+}
+
+inline static size_t kai_get_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_get_k_roundedup(k);
+    size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
+    // Since the LHS matrix is asymmetric with per-row quantization, we must include
+    // the number of bytes to hold the zero point value
+    lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
+
+    return lhs_packed_stride;
+}
+
+inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
+
+    size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
+    // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
+    // the number of bytes for the reduction sum
+    rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
+    // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kai_bl) == 0); + + if (m == 0) { + return; + } + const size_t num_subblocks = bl / kai_bl; + const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + args.num_subblocks = num_subblocks; + + kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(&args); +} + +#endif // Architectural features check. 
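The variant above exposes the standard getter/run entry points, so it can be dispatched through the shared kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_ukernel interface added later in this patch. A hypothetical initializer, not part of the patch, e.g. for runtime selection between the dotprod and i8mm variants:

    #include "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h"
    #include "kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h"

    // Wires the 1x4 dotprod variant into the shared interface struct; a caller
    // would pick this entry when FEAT_DotProd is available but FEAT_I8MM is not.
    static const struct kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_ukernel ukernel_dotprod = {
        .get_m_step = kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_n_step = kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_mr = kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_nr = kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_kr = kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_sr = kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_dst_offset = kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .get_dst_size = kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
        .run_matmul = kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod,
    };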
diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h
new file mode 100644
index 00000000..bcfaa291
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h
@@ -0,0 +1,146 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// Micro-kernel dependencies
+///
+/// -# @ref kai_lhs_quant_pack_qai8dxp_bf16_neon to dynamically quantize and pack the LHS matrix in a single step.
+/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix.
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step
+size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix.
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix.
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1.
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(
+    size_t m_idx, //
+    size_t k); //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (multiple of 32) quantization (qsi4c32)
+/// values.
+///
+/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+/// It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(
+    size_t n_idx, //
+    size_t k, //
+    size_t bl); //
+
+/// Gets the offset in bytes for the DST matrix.
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be 1.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(
+    size_t m_idx, //
+    size_t n_idx, //
+    size_t dst_stride); //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod(
+    size_t m, //
+    size_t n); //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (multiple of 32) quantization (qsi4c32) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: dotprod
+///
+/// @param[in] m The number of output rows written. It must be 1.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+/// top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+/// top of this file.
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes.
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod_asm.S new file mode 100644 index 00000000..6079e86f --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod_asm.S @@ -0,0 +1,159 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x15, #0x20 + movi v31.16b, #0xf0 + mov x21, #0x8 + ldr x14, [x0, #0x40] + ldr x13, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x12, [x0, #0x8] + ldr x11, [x0, #0x10] + ldr x10, [x0, #0x30] + mul x15, x14, x15 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + mov x26, x20 + madd x15, x13, x15, x21 +KAI_ASM_LABEL(label_1) // Row loop + mov x25, x11 + mov x24, x10 + add x23, x9, x28 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x12 + movi v30.16b, #0x0 + mov x21, x13 +KAI_ASM_LABEL(label_3) // Block loop + movi v29.4s, #0x0 + movi v28.4s, #0x0 + mov x20, x14 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q27, [x25, #0x0] + ldr q26, [x25, #0x10] + subs x20, x20, #0x1 + ld1r { v25.2d }, [x22], #0x8 + ldr q24, [x25, #0x20] + ldr q23, [x25, #0x30] + add x25, x25, #0x40 + ld1r { v22.2d }, [x22], #0x8 + ld1r { v21.2d }, [x22], #0x8 + shl v20.16b, v27.16b, #0x4 + shl v19.16b, v26.16b, #0x4 + ld1r { v18.2d }, [x22], #0x8 + shl v17.16b, v24.16b, #0x4 + and v27.16b, v27.16b, v31.16b + shl v16.16b, v23.16b, #0x4 + and v26.16b, v26.16b, v31.16b + KAI_ASM_INST(0x4e99969d) // sdot v29.4s, v20.16b, v25.16b + KAI_ASM_INST(0x4e99967c) // sdot v28.4s, v19.16b, v25.16b + and v24.16b, v24.16b, v31.16b + and v23.16b, v23.16b, v31.16b + KAI_ASM_INST(0x4e96963d) // sdot v29.4s, v17.16b, v22.16b + KAI_ASM_INST(0x4e96961c) // sdot v28.4s, v16.16b, v22.16b + KAI_ASM_INST(0x4e95977d) // sdot v29.4s, v27.16b, v21.16b + KAI_ASM_INST(0x4e95975c) // sdot v28.4s, v26.16b, v21.16b + KAI_ASM_INST(0x4e92971d) // sdot v29.4s, v24.16b, v18.16b + KAI_ASM_INST(0x4e9296fc) // sdot v28.4s, v23.16b, v18.16b + bgt label_4 + ldr d16, [x25, #0x0] + addp v29.4s, v29.4s, v28.4s + sub x21, x21, #0x1 + add x25, x25, #0x8 + shll v16.4s, v16.4h, #0x10 + scvtf v29.4s, v29.4s, #0x4 + fmla v30.4s, v29.4s, v16.4s + cbnz x21, label_3 + ld1r { v21.4s }, [x22] + ldr q20, [x25, #0x0] + add x22, x22, #0x4 + add x20, x27, #0x4 + ld1r { v19.4s }, [x22] + ldr q18, [x25, #0x10] + cmp x24, #0x4 + add x25, x25, #0x20 + ld1r { v17.4s }, [x27] + ld1r { v16.4s }, [x20] + scvtf v21.4s, v21.4s + fmla v30.4s, v20.4s, v21.s[0] + fmul v30.4s, v30.4s, v19.4s + fadd v30.4s, v30.4s, v18.4s + fmax v30.4s, v30.4s, v17.4s + fmin v30.4s, v30.4s, v16.4s + KAI_ASM_INST(0x0ea16bd0) // bfcvtn v16.4h, v30.4s + blt label_5 + str d16, [x9, #0x0] + b label_8 +KAI_ASM_LABEL(label_5) // Partial output + mov x20, x9 + tbz x24, #1, label_6 + st1 { v16.s }[0], [x20], #0x4 + tbz x24, #0, label_7 + st1 { v16.h }[2], [x20] + b label_7 +KAI_ASM_LABEL(label_6) // Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] +KAI_ASM_LABEL(label_7) // Output block 0: Done +KAI_ASM_LABEL(label_8) // Stores done + subs x24, x24, #0x4 + add x9, x9, #0x8 + bgt label_2 + subs x26, x26, #0x1 + add x12, x12, x15 + mov x9, x23 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.c new file mode 100644 index 00000000..d327b7e2 --- /dev/null +++ 
b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.c
@@ -0,0 +1,185 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64)
+#error "I8mm extension required to compile this micro-kernel"
+#else // Architectural features check.
+
+#include "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+typedef struct {
+    void* dst;
+    const void* lhs_packed;
+    const void* rhs_packed;
+    const float* clamp_vals;
+    size_t dst_stride_row;
+    size_t m;
+    size_t n;
+    size_t num_blocks;
+    size_t num_subblocks;
+} KernelArgs;
+
+void kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(KernelArgs* args_ptr);
+
+// Compute args
+static const size_t kai_m_step = 16;
+static const size_t kai_n_step = 4;
+// Packing args
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
+static const size_t kai_num_bytes_qvalue_lhs = 1;
+static const size_t kai_num_bytes_multiplier_lhs = 4;
+static const size_t kai_num_bytes_zp_lhs = 4;
+// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
+// asymmetric))
+static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
+static const size_t kai_num_bytes_multiplier_rhs = 2;
+static const size_t kai_num_bytes_rsum_rhs = 4;
+// DST format args
+static const size_t kai_num_bytes_dst_value = 2;
+// Extra args
+static const size_t kai_num_bytes_bias = 4;
+static const size_t kai_k_multiple_of = 32;
+static const size_t kai_bl = 32;
+
+inline static size_t kai_get_k_roundedup(size_t k) {
+    return kai_roundup(k, kai_k_multiple_of);
+}
+
+inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+    size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
+    return num_bytes_per_block_rhs;
+}
+
+inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    return kai_roundup(k, bl) / bl;
+}
+
+inline static size_t kai_get_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_get_k_roundedup(k);
+    size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
+    // Since the LHS matrix is asymmetric with per-row quantization, we must include
+    // the number of bytes to hold the zero point value
+    lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
+
+    return lhs_packed_stride;
+}
+
+inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
+
+    size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
+    // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
+    // the number of bytes for the reduction sum
+    rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
+    // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
+    rhs_packed_stride += kai_nr * kai_num_bytes_bias;
+
+    return rhs_packed_stride;
+}
+
+size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void) {
+    return kai_m_step;
+}
+
+size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void) {
+    return kai_n_step;
+}
+
+size_t kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(size_t m_idx, size_t k) {
+    KAI_ASSUME((m_idx % kai_m_step) == 0);
+
+    return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t n_idx, size_t k, size_t bl) {
+    KAI_ASSUME((k % bl) == 0);
+    KAI_ASSUME((n_idx % kai_n_step) == 0);
+
+    return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl);
+}
+
+size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t m_idx, size_t n_idx, size_t dst_stride) {
+    KAI_ASSUME((m_idx % kai_m_step) == 0);
+    KAI_ASSUME((n_idx % kai_n_step) == 0);
+
+    return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
+}
+
+size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(size_t m, size_t n) {
+    return m * n * kai_num_bytes_dst_value;
+}
+
+void kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t m, //
+    size_t n, //
+    size_t k, //
+    size_t bl, //
+    const void* restrict lhs_packed, //
+    const void* restrict rhs_packed, //
+    void* restrict dst, // NOLINT(readability-non-const-parameter)
+    size_t dst_stride_row, //
+    size_t dst_stride_col, //
+    float scalar_min, //
+    float scalar_max) {
+    KAI_ASSUME(dst_stride_col == sizeof(uint16_t));
+    KAI_ASSUME((k % bl) == 0);
+    KAI_ASSUME((bl % kai_bl) == 0);
+
+    if (m == 0) {
+        return;
+    }
+    const size_t num_subblocks = bl / kai_bl;
+    const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
+    const float clamp_vals[2] = {scalar_min, scalar_max};
+
+    KernelArgs args;
+
+    args.dst = dst;
+    args.lhs_packed = lhs_packed;
+    args.rhs_packed = rhs_packed;
+    args.clamp_vals = clamp_vals;
+    args.dst_stride_row = dst_stride_row;
+    args.m = m;
+    args.n = n;
+    args.num_blocks = num_blocks;
+    args.num_subblocks = num_subblocks;
+
+    kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(&args);
+}
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h
new file mode 100644
index 00000000..d84b6b9c
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h
@@ -0,0 +1,146 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// Micro-kernel dependencies
+///
+/// -# @ref kai_lhs_quant_pack_qai8dxp_bf16_neon to dynamically quantize and pack the LHS matrix in a single step.
+/// -# @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS KxN matrix.
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step
+size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix.
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix.
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t m_idx, //
+    size_t k); //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Quantized Symmetric Signed 4-bit with per-block (multiple of 32) quantization (qsi4c32)
+/// values.
+///
+/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+/// It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t n_idx, //
+    size_t k, //
+    size_t bl); //
+
+/// Gets the offset in bytes for the DST matrix.
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t m_idx, //
+    size_t n_idx, //
+    size_t dst_stride); //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t m, //
+    size_t n); //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed.
+/// RHS matrix: Quantized Symmetric Signed 4-bit with per-block (multiple of 32) quantization (qsi4c32) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: i8mm
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// It must be a multiple of the block length (bl).
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+/// top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+/// top of this file.
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes.
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm(
+    size_t m, //
+    size_t n, //
+    size_t k, //
+    size_t bl, //
+    const void* lhs_packed, //
+    const void* rhs_packed, //
+    void* dst, //
+    size_t dst_stride_row, //
+    size_t dst_stride_col, //
+    float scalar_min, //
+    float scalar_max); //
+
+#ifdef __cplusplus
+}
+#endif // __cplusplus
diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm_asm.S
new file mode 100644
index 00000000..a11b5f00
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm_asm.S
@@ -0,0 +1,764 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(_MSC_VER)
+    #define KAI_ASM_GLOBAL(name) GLOBAL name
+    #define KAI_ASM_FUNCTION_TYPE(name)
+    #define KAI_ASM_FUNCTION_LABEL(name) name PROC
+    #define KAI_ASM_FUNCTION_END(name) ENDP
+
+    #define KAI_ASM_CODE(name) AREA name, CODE, READONLY
+    #define KAI_ASM_ALIGN
+    #define KAI_ASM_LABEL(name) name
+    #define KAI_ASM_INST(hex) DCD hex
+    #define KAI_ASM_END END
+#else
+    #if defined(__APPLE__)
+        #define KAI_ASM_GLOBAL(name) .globl _##name
+        #define KAI_ASM_FUNCTION_TYPE(name)
+        #define KAI_ASM_FUNCTION_LABEL(name) _##name:
+        #define KAI_ASM_FUNCTION_END(name)
+    #else
+        #define KAI_ASM_GLOBAL(name) .global name
+        #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function
+        #define KAI_ASM_FUNCTION_LABEL(name) name:
+        #define KAI_ASM_FUNCTION_END(name) .size name, .-name
+    #endif
+
+    #define KAI_ASM_CODE(name) .text
+    #define KAI_ASM_ALIGN .p2align 4,,11
+    #define KAI_ASM_LABEL(name) name:
+    #define KAI_ASM_INST(hex)
.inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm) + stp x20, x21, [sp, -144]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x5, #0x80 + mov x21, #0x20 + sub SP, SP, #0x100 + ldr x20, [x0, #0x28] + ldr x6, [x0, #0x40] + ldr x7, [x0, #0x38] + ldr x8, [x0, #0x8] + ldr x17, [x0, #0x10] + ldr x16, [x0, #0x30] + mov x15, x20 + mul x5, x6, x5 + ldr x14, [x0, #0x0] + ldr x13, [x0, #0x20] + ldr x12, [x0, #0x18] + cmp x15, #0x10 + madd x5, x7, x5, x21 + blt label_15 +KAI_ASM_LABEL(label_1) // Row loop + mov x11, x17 + mov x10, x16 + add x9, x14, x13, LSL #4 +KAI_ASM_LABEL(label_2) // Column loop + mov x27, x8 + movi v6.4s, #0x0 + mov x24, x7 + str q6, [SP, #0x0] + str q6, [SP, #0x10] + str q6, [SP, #0x20] + add x23, x27, x5 + add x22, x23, x5 + str q6, [SP, #0x30] + add x21, x22, x5 + str q6, [SP, #0x40] + str q6, [SP, #0x50] + str q6, [SP, #0x60] + str q6, [SP, #0x70] + str q6, [SP, #0x80] + str q6, [SP, #0x90] + str q6, [SP, #0xa0] + str q6, [SP, #0xb0] + str q6, [SP, #0xc0] + str q6, [SP, #0xd0] + str q6, [SP, #0xe0] + str q6, [SP, #0xf0] +KAI_ASM_LABEL(label_3) // Block loop + movi v14.4s, #0x0 + movi v24.4s, #0x0 + mov x20, x6 + movi v13.4s, #0x0 + movi v11.4s, #0x0 + movi v15.4s, #0x0 + movi v22.4s, #0x0 + movi v12.4s, #0x0 + movi v4.4s, #0x0 + movi v25.4s, #0x0 + movi v17.4s, #0x0 + movi v5.4s, #0x0 + movi v10.4s, #0x0 + movi v20.4s, #0x0 + movi v9.4s, #0x0 + movi v27.4s, #0x0 + movi v18.4s, #0x0 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q8, [x11, #0x0] + ldr q3, [x11, #0x10] + movi v16.16b, #0xf0 + subs x20, x20, #0x1 + ldr q31, [x27, #0x0] + ldr q2, [x27, #0x10] + ldr q28, [x23, #0x0] + ldr q23, [x23, #0x10] + ldr q26, [x22, #0x0] + ldr q30, [x22, #0x10] + shl v21.16b, v8.16b, #0x4 + shl v1.16b, v3.16b, #0x4 + ldr q7, [x21, #0x0] + ldr q19, [x21, #0x10] + and v8.16b, v8.16b, v16.16b + and v3.16b, v3.16b, v16.16b + ldr q0, [x11, #0x20] + ldr q29, [x11, #0x30] + add x11, x11, #0x40 + ldr q6, [x27, #0x20] + KAI_ASM_INST(0x4e95a7ee) // smmla v14.4s, v31.16b, v21.16b + KAI_ASM_INST(0x4e81a7f8) // smmla v24.4s, v31.16b, v1.16b + ldr q31, [x27, #0x30] + KAI_ASM_INST(0x4e95a44d) // smmla v13.4s, v2.16b, v21.16b + KAI_ASM_INST(0x4e81a44b) // smmla v11.4s, v2.16b, v1.16b + ldr q2, [x23, #0x20] + KAI_ASM_INST(0x4e95a78f) // smmla v15.4s, v28.16b, v21.16b + KAI_ASM_INST(0x4e81a796) // smmla v22.4s, v28.16b, v1.16b + ldr q28, [x23, #0x30] + KAI_ASM_INST(0x4e95a6ec) // smmla v12.4s, v23.16b, v21.16b + KAI_ASM_INST(0x4e81a6e4) // smmla v4.4s, v23.16b, v1.16b + ldr q23, [x22, #0x20] + KAI_ASM_INST(0x4e95a759) // smmla v25.4s, v26.16b, v21.16b + KAI_ASM_INST(0x4e81a751) // smmla v17.4s, v26.16b, v1.16b + ldr q26, [x22, #0x30] + KAI_ASM_INST(0x4e95a7c5) // smmla v5.4s, v30.16b, v21.16b + KAI_ASM_INST(0x4e81a7ca) // smmla v10.4s, v30.16b, v1.16b + ldr q30, [x21, #0x20] + KAI_ASM_INST(0x4e95a4f4) // smmla v20.4s, v7.16b, v21.16b + KAI_ASM_INST(0x4e81a4e9) // smmla v9.4s, v7.16b, v1.16b + ldr q7, [x21, #0x30] + KAI_ASM_INST(0x4e95a67b) // smmla v27.4s, v19.16b, v21.16b + ldr q21, [x27, #0x40] + 
KAI_ASM_INST(0x4e81a672) // smmla v18.4s, v19.16b, v1.16b + ldr q19, [x27, #0x50] + shl v1.16b, v0.16b, #0x4 + and v0.16b, v0.16b, v16.16b + KAI_ASM_INST(0x4e81a4ce) // smmla v14.4s, v6.16b, v1.16b + KAI_ASM_INST(0x4e81a7ed) // smmla v13.4s, v31.16b, v1.16b + KAI_ASM_INST(0x4e81a44f) // smmla v15.4s, v2.16b, v1.16b + KAI_ASM_INST(0x4e81a78c) // smmla v12.4s, v28.16b, v1.16b + KAI_ASM_INST(0x4e81a6f9) // smmla v25.4s, v23.16b, v1.16b + KAI_ASM_INST(0x4e81a745) // smmla v5.4s, v26.16b, v1.16b + KAI_ASM_INST(0x4e81a7d4) // smmla v20.4s, v30.16b, v1.16b + KAI_ASM_INST(0x4e81a4fb) // smmla v27.4s, v7.16b, v1.16b + shl v1.16b, v29.16b, #0x4 + KAI_ASM_INST(0x4e88a6ae) // smmla v14.4s, v21.16b, v8.16b + KAI_ASM_INST(0x4e88a66d) // smmla v13.4s, v19.16b, v8.16b + and v29.16b, v29.16b, v16.16b + ldr q16, [x23, #0x40] + KAI_ASM_INST(0x4e81a4d8) // smmla v24.4s, v6.16b, v1.16b + ldr q6, [x23, #0x50] + KAI_ASM_INST(0x4e81a7eb) // smmla v11.4s, v31.16b, v1.16b + ldr q31, [x22, #0x40] + KAI_ASM_INST(0x4e81a456) // smmla v22.4s, v2.16b, v1.16b + ldr q2, [x22, #0x50] + KAI_ASM_INST(0x4e81a784) // smmla v4.4s, v28.16b, v1.16b + ldr q28, [x21, #0x40] + KAI_ASM_INST(0x4e81a6f1) // smmla v17.4s, v23.16b, v1.16b + ldr q23, [x21, #0x50] + KAI_ASM_INST(0x4e81a74a) // smmla v10.4s, v26.16b, v1.16b + ldr q26, [x27, #0x60] + KAI_ASM_INST(0x4e81a7c9) // smmla v9.4s, v30.16b, v1.16b + ldr q30, [x27, #0x70] + KAI_ASM_INST(0x4e81a4f2) // smmla v18.4s, v7.16b, v1.16b + ldr q1, [x23, #0x60] + KAI_ASM_INST(0x4e83a6b8) // smmla v24.4s, v21.16b, v3.16b + ldr q21, [x23, #0x70] + ldr q7, [x22, #0x60] + KAI_ASM_INST(0x4e83a66b) // smmla v11.4s, v19.16b, v3.16b + KAI_ASM_INST(0x4e88a60f) // smmla v15.4s, v16.16b, v8.16b + ldr q19, [x22, #0x70] + KAI_ASM_INST(0x4e83a616) // smmla v22.4s, v16.16b, v3.16b + ldr q16, [x21, #0x60] + KAI_ASM_INST(0x4e88a4cc) // smmla v12.4s, v6.16b, v8.16b + KAI_ASM_INST(0x4e83a4c4) // smmla v4.4s, v6.16b, v3.16b + ldr q6, [x21, #0x70] + KAI_ASM_INST(0x4e88a7f9) // smmla v25.4s, v31.16b, v8.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4e83a7f1) // smmla v17.4s, v31.16b, v3.16b + KAI_ASM_INST(0x4e88a445) // smmla v5.4s, v2.16b, v8.16b + add x23, x23, #0x80 + add x22, x22, #0x80 + KAI_ASM_INST(0x4e83a44a) // smmla v10.4s, v2.16b, v3.16b + KAI_ASM_INST(0x4e88a794) // smmla v20.4s, v28.16b, v8.16b + add x21, x21, #0x80 + KAI_ASM_INST(0x4e83a789) // smmla v9.4s, v28.16b, v3.16b + KAI_ASM_INST(0x4e88a6fb) // smmla v27.4s, v23.16b, v8.16b + KAI_ASM_INST(0x4e83a6f2) // smmla v18.4s, v23.16b, v3.16b + KAI_ASM_INST(0x4e80a74e) // smmla v14.4s, v26.16b, v0.16b + KAI_ASM_INST(0x4e9da758) // smmla v24.4s, v26.16b, v29.16b + KAI_ASM_INST(0x4e80a7cd) // smmla v13.4s, v30.16b, v0.16b + KAI_ASM_INST(0x4e9da7cb) // smmla v11.4s, v30.16b, v29.16b + KAI_ASM_INST(0x4e80a42f) // smmla v15.4s, v1.16b, v0.16b + KAI_ASM_INST(0x4e9da436) // smmla v22.4s, v1.16b, v29.16b + KAI_ASM_INST(0x4e80a6ac) // smmla v12.4s, v21.16b, v0.16b + KAI_ASM_INST(0x4e9da6a4) // smmla v4.4s, v21.16b, v29.16b + KAI_ASM_INST(0x4e80a4f9) // smmla v25.4s, v7.16b, v0.16b + KAI_ASM_INST(0x4e9da4f1) // smmla v17.4s, v7.16b, v29.16b + KAI_ASM_INST(0x4e80a665) // smmla v5.4s, v19.16b, v0.16b + KAI_ASM_INST(0x4e9da66a) // smmla v10.4s, v19.16b, v29.16b + KAI_ASM_INST(0x4e80a614) // smmla v20.4s, v16.16b, v0.16b + KAI_ASM_INST(0x4e9da609) // smmla v9.4s, v16.16b, v29.16b + KAI_ASM_INST(0x4e80a4db) // smmla v27.4s, v6.16b, v0.16b + KAI_ASM_INST(0x4e9da4d2) // smmla v18.4s, v6.16b, v29.16b + bgt label_4 + ldr d19, [x11, #0x0] + ldr q28, [SP, #0x0] + uzp1 v1.2d, 
v14.2d, v24.2d + uzp2 v23.2d, v14.2d, v24.2d + ldr q21, [SP, #0x10] + ldr q6, [SP, #0x20] + uzp1 v31.2d, v13.2d, v11.2d + uzp2 v24.2d, v13.2d, v11.2d + ldr q16, [SP, #0x30] + add x11, x11, #0x8 + shll v26.4s, v19.4h, #0x10 + scvtf v1.4s, v1.4s, #0x4 + scvtf v23.4s, v23.4s, #0x4 + scvtf v31.4s, v31.4s, #0x4 + scvtf v24.4s, v24.4s, #0x4 + fmla v28.4s, v1.4s, v26.4s + fmla v21.4s, v23.4s, v26.4s + fmla v6.4s, v31.4s, v26.4s + fmla v16.4s, v24.4s, v26.4s + str q28, [SP, #0x0] + str q21, [SP, #0x10] + str q6, [SP, #0x20] + str q16, [SP, #0x30] + ldr q11, [SP, #0x40] + ldr q8, [SP, #0x50] + uzp1 v23.2d, v15.2d, v22.2d + uzp2 v21.2d, v15.2d, v22.2d + ldr q19, [SP, #0x60] + ldr q2, [SP, #0x70] + uzp1 v3.2d, v12.2d, v4.2d + uzp2 v16.2d, v12.2d, v4.2d + scvtf v23.4s, v23.4s, #0x4 + scvtf v21.4s, v21.4s, #0x4 + scvtf v3.4s, v3.4s, #0x4 + scvtf v16.4s, v16.4s, #0x4 + fmla v11.4s, v23.4s, v26.4s + fmla v8.4s, v21.4s, v26.4s + fmla v19.4s, v3.4s, v26.4s + fmla v2.4s, v16.4s, v26.4s + str q11, [SP, #0x40] + str q8, [SP, #0x50] + str q19, [SP, #0x60] + str q2, [SP, #0x70] + ldr q8, [SP, #0x80] + ldr q23, [SP, #0x90] + uzp1 v3.2d, v25.2d, v17.2d + uzp2 v21.2d, v25.2d, v17.2d + ldr q19, [SP, #0xa0] + ldr q4, [SP, #0xb0] + uzp1 v24.2d, v5.2d, v10.2d + uzp2 v16.2d, v5.2d, v10.2d + scvtf v3.4s, v3.4s, #0x4 + scvtf v21.4s, v21.4s, #0x4 + scvtf v24.4s, v24.4s, #0x4 + scvtf v16.4s, v16.4s, #0x4 + fmla v8.4s, v3.4s, v26.4s + fmla v23.4s, v21.4s, v26.4s + fmla v19.4s, v24.4s, v26.4s + fmla v4.4s, v16.4s, v26.4s + str q8, [SP, #0x80] + str q23, [SP, #0x90] + str q19, [SP, #0xa0] + str q4, [SP, #0xb0] + ldr q23, [SP, #0xc0] + ldr q22, [SP, #0xd0] + uzp1 v21.2d, v20.2d, v9.2d + uzp2 v20.2d, v20.2d, v9.2d + ldr q19, [SP, #0xe0] + ldr q8, [SP, #0xf0] + uzp1 v4.2d, v27.2d, v18.2d + uzp2 v16.2d, v27.2d, v18.2d + scvtf v21.4s, v21.4s, #0x4 + scvtf v20.4s, v20.4s, #0x4 + scvtf v4.4s, v4.4s, #0x4 + scvtf v16.4s, v16.4s, #0x4 + fmla v23.4s, v21.4s, v26.4s + fmla v22.4s, v20.4s, v26.4s + fmla v19.4s, v4.4s, v26.4s + fmla v8.4s, v16.4s, v26.4s + str q23, [SP, #0xc0] + str q22, [SP, #0xd0] + str q19, [SP, #0xe0] + str q8, [SP, #0xf0] + subs x24, x24, #0x1 + bgt label_3 + ld1 { v11.4s }, [x27] + ld1 { v10.4s }, [x23] + add x27, x27, #0x10 + add x23, x23, #0x10 + ld1 { v9.4s }, [x22] + ld1 { v8.4s }, [x21] + add x22, x22, #0x10 + add x21, x21, #0x10 + ldr q31, [SP, #0x0] + ldr q30, [SP, #0x10] + add x20, x12, #0x4 + cmp x10, #0x4 + ldr q29, [SP, #0x20] + ldr q28, [SP, #0x30] + scvtf v11.4s, v11.4s + scvtf v10.4s, v10.4s + ldr q27, [SP, #0x40] + ldr q26, [SP, #0x50] + scvtf v9.4s, v9.4s + scvtf v8.4s, v8.4s + ldr q25, [SP, #0x60] + ldr q24, [SP, #0x70] + ldr q23, [SP, #0x80] + ldr q22, [SP, #0x90] + ldr q21, [SP, #0xa0] + ldr q20, [SP, #0xb0] + ldr q19, [SP, #0xc0] + ldr q18, [SP, #0xd0] + ldr q17, [SP, #0xe0] + ldr q16, [SP, #0xf0] + ldr q7, [x11, #0x0] + ldr q6, [x27, #0x0] + ldr q5, [x23, #0x0] + ldr q4, [x22, #0x0] + ldr q3, [x21, #0x0] + ldr q2, [x11, #0x10] + add x11, x11, #0x20 + ld1r { v1.4s }, [x12] + ld1r { v0.4s }, [x20] + fmla v31.4s, v7.4s, v11.s[0] + fmla v30.4s, v7.4s, v11.s[1] + fmla v29.4s, v7.4s, v11.s[2] + fmla v28.4s, v7.4s, v11.s[3] + fmla v27.4s, v7.4s, v10.s[0] + fmla v26.4s, v7.4s, v10.s[1] + fmla v25.4s, v7.4s, v10.s[2] + fmla v24.4s, v7.4s, v10.s[3] + fmla v23.4s, v7.4s, v9.s[0] + fmla v22.4s, v7.4s, v9.s[1] + fmul v31.4s, v31.4s, v6.s[0] + fmla v21.4s, v7.4s, v9.s[2] + fmla v20.4s, v7.4s, v9.s[3] + fmul v30.4s, v30.4s, v6.s[1] + fmla v19.4s, v7.4s, v8.s[0] + fmla v18.4s, v7.4s, v8.s[1] + fmul v29.4s, 
v29.4s, v6.s[2] + fmla v17.4s, v7.4s, v8.s[2] + fmla v16.4s, v7.4s, v8.s[3] + fmul v28.4s, v28.4s, v6.s[3] + fmul v27.4s, v27.4s, v5.s[0] + fmul v26.4s, v26.4s, v5.s[1] + fmul v25.4s, v25.4s, v5.s[2] + fmul v24.4s, v24.4s, v5.s[3] + fmul v23.4s, v23.4s, v4.s[0] + fmul v22.4s, v22.4s, v4.s[1] + fmul v21.4s, v21.4s, v4.s[2] + fmul v20.4s, v20.4s, v4.s[3] + fmul v19.4s, v19.4s, v3.s[0] + fmul v18.4s, v18.4s, v3.s[1] + fmul v17.4s, v17.4s, v3.s[2] + fmul v16.4s, v16.4s, v3.s[3] + fadd v31.4s, v31.4s, v2.4s + fadd v30.4s, v30.4s, v2.4s + fadd v29.4s, v29.4s, v2.4s + fadd v28.4s, v28.4s, v2.4s + fadd v27.4s, v27.4s, v2.4s + fadd v26.4s, v26.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + fadd v24.4s, v24.4s, v2.4s + fadd v23.4s, v23.4s, v2.4s + fadd v22.4s, v22.4s, v2.4s + fadd v21.4s, v21.4s, v2.4s + fadd v20.4s, v20.4s, v2.4s + fadd v19.4s, v19.4s, v2.4s + fadd v18.4s, v18.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + fadd v16.4s, v16.4s, v2.4s + fmax v31.4s, v31.4s, v1.4s + fmax v30.4s, v30.4s, v1.4s + fmax v29.4s, v29.4s, v1.4s + fmax v28.4s, v28.4s, v1.4s + fmax v27.4s, v27.4s, v1.4s + fmax v26.4s, v26.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + fmax v24.4s, v24.4s, v1.4s + fmax v23.4s, v23.4s, v1.4s + fmax v22.4s, v22.4s, v1.4s + fmax v21.4s, v21.4s, v1.4s + fmax v20.4s, v20.4s, v1.4s + fmax v19.4s, v19.4s, v1.4s + fmax v18.4s, v18.4s, v1.4s + fmax v17.4s, v17.4s, v1.4s + fmax v16.4s, v16.4s, v1.4s + fmin v31.4s, v31.4s, v0.4s + fmin v30.4s, v30.4s, v0.4s + fmin v29.4s, v29.4s, v0.4s + fmin v28.4s, v28.4s, v0.4s + fmin v27.4s, v27.4s, v0.4s + fmin v26.4s, v26.4s, v0.4s + fmin v25.4s, v25.4s, v0.4s + fmin v24.4s, v24.4s, v0.4s + fmin v23.4s, v23.4s, v0.4s + fmin v22.4s, v22.4s, v0.4s + fmin v21.4s, v21.4s, v0.4s + fmin v20.4s, v20.4s, v0.4s + fmin v19.4s, v19.4s, v0.4s + fmin v18.4s, v18.4s, v0.4s + fmin v17.4s, v17.4s, v0.4s + fmin v16.4s, v16.4s, v0.4s + KAI_ASM_INST(0x0ea16bff) // bfcvtn v31.4h, v31.4s + KAI_ASM_INST(0x0ea16bde) // bfcvtn v30.4h, v30.4s + KAI_ASM_INST(0x0ea16bbd) // bfcvtn v29.4h, v29.4s + KAI_ASM_INST(0x0ea16b9c) // bfcvtn v28.4h, v28.4s + KAI_ASM_INST(0x0ea16b7b) // bfcvtn v27.4h, v27.4s + KAI_ASM_INST(0x0ea16b5a) // bfcvtn v26.4h, v26.4s + KAI_ASM_INST(0x0ea16b39) // bfcvtn v25.4h, v25.4s + KAI_ASM_INST(0x0ea16b18) // bfcvtn v24.4h, v24.4s + KAI_ASM_INST(0x0ea16af7) // bfcvtn v23.4h, v23.4s + KAI_ASM_INST(0x0ea16ad6) // bfcvtn v22.4h, v22.4s + KAI_ASM_INST(0x0ea16ab5) // bfcvtn v21.4h, v21.4s + KAI_ASM_INST(0x0ea16a94) // bfcvtn v20.4h, v20.4s + KAI_ASM_INST(0x0ea16a73) // bfcvtn v19.4h, v19.4s + KAI_ASM_INST(0x0ea16a52) // bfcvtn v18.4h, v18.4s + KAI_ASM_INST(0x0ea16a31) // bfcvtn v17.4h, v17.4s + KAI_ASM_INST(0x0ea16a10) // bfcvtn v16.4h, v16.4s + blt label_9 + mov x20, x14 + str d31, [x20, #0x0] + add x20, x20, x13 + str d30, [x20, #0x0] + add x20, x20, x13 + str d29, [x20, #0x0] + add x20, x20, x13 + str d28, [x20, #0x0] + add x20, x20, x13 + str d27, [x20, #0x0] + add x20, x20, x13 + str d26, [x20, #0x0] + add x20, x20, x13 + str d25, [x20, #0x0] + add x20, x20, x13 + str d24, [x20, #0x0] + add x20, x20, x13 + str d23, [x20, #0x0] + add x20, x20, x13 + str d22, [x20, #0x0] + add x20, x20, x13 + str d21, [x20, #0x0] + add x20, x20, x13 + str d20, [x20, #0x0] + add x20, x20, x13 + str d19, [x20, #0x0] + add x20, x20, x13 + str d18, [x20, #0x0] + add x20, x20, x13 + str d17, [x20, #0x0] + add x20, x20, x13 + str d16, [x20, #0x0] + b label_14 +KAI_ASM_LABEL(label_9) // Partial output + mov x28, x14 + add x26, x28, x13, LSL #2 + add x25, x26, x13, LSL #1 + add x24, x26, x13 + add 
x23, x25, x13 + add x22, x28, x13, LSL #1 + add x21, x28, x13 + add x20, x22, x13 + add x27, x23, x13 + tbz x10, #1, label_10 + st1 { v24.s }[0], [x23], #0x4 + st1 { v25.s }[0], [x25], #0x4 + st1 { v26.s }[0], [x24], #0x4 + st1 { v27.s }[0], [x26], #0x4 + st1 { v28.s }[0], [x20], #0x4 + st1 { v29.s }[0], [x22], #0x4 + st1 { v30.s }[0], [x21], #0x4 + st1 { v31.s }[0], [x28], #0x4 + tbz x10, #0, label_11 + st1 { v24.h }[2], [x23] + st1 { v25.h }[2], [x25] + st1 { v26.h }[2], [x24] + st1 { v27.h }[2], [x26] + st1 { v28.h }[2], [x20] + st1 { v29.h }[2], [x22] + st1 { v30.h }[2], [x21] + st1 { v31.h }[2], [x28] + b label_11 +KAI_ASM_LABEL(label_10) // Output block 0: partial_1_0 + st1 { v24.h }[0], [x23] + st1 { v25.h }[0], [x25] + st1 { v26.h }[0], [x24] + st1 { v27.h }[0], [x26] + st1 { v28.h }[0], [x20] + st1 { v29.h }[0], [x22] + st1 { v30.h }[0], [x21] + st1 { v31.h }[0], [x28] +KAI_ASM_LABEL(label_11) // Output block 0: Done + add x26, x27, x13, LSL #2 + add x25, x27, x13, LSL #1 + add x24, x26, x13, LSL #1 + add x23, x27, x13 + add x22, x25, x13 + add x21, x26, x13 + add x20, x24, x13 + tbz x10, #1, label_12 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x24], #0x4 + st1 { v18.s }[0], [x21], #0x4 + st1 { v19.s }[0], [x26], #0x4 + st1 { v20.s }[0], [x22], #0x4 + st1 { v21.s }[0], [x25], #0x4 + st1 { v22.s }[0], [x23], #0x4 + st1 { v23.s }[0], [x27], #0x4 + tbz x10, #0, label_13 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x24] + st1 { v18.h }[2], [x21] + st1 { v19.h }[2], [x26] + st1 { v20.h }[2], [x22] + st1 { v21.h }[2], [x25] + st1 { v22.h }[2], [x23] + st1 { v23.h }[2], [x27] + b label_13 +KAI_ASM_LABEL(label_12) // Output block 1: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x24] + st1 { v18.h }[0], [x21] + st1 { v19.h }[0], [x26] + st1 { v20.h }[0], [x22] + st1 { v21.h }[0], [x25] + st1 { v22.h }[0], [x23] + st1 { v23.h }[0], [x27] +KAI_ASM_LABEL(label_13) // Output block 1: Done +KAI_ASM_LABEL(label_14) // Output stage exit + subs x10, x10, #0x4 + add x14, x14, #0x8 + bgt label_2 + mov x20, #0x4 + sub x15, x15, #0x10 + cmp x15, #0x10 + mov x14, x9 + madd x8, x20, x5, x8 + bge label_1 +KAI_ASM_LABEL(label_15) // Row loop skip + cbz x15, label_25 +KAI_ASM_LABEL(label_16) // Row tail: Row loop + mov x26, x17 + mov x25, x16 + add x24, x14, x13, LSL #2 +KAI_ASM_LABEL(label_17) // Row tail: Column loop + movi v16.4s, #0x0 + mov x27, x8 + mov x21, x7 + str q16, [SP, #0x0] + str q16, [SP, #0x10] + str q16, [SP, #0x20] + str q16, [SP, #0x30] +KAI_ASM_LABEL(label_18) // Row tail: Block loop + movi v14.4s, #0x0 + movi v24.4s, #0x0 + mov x20, x6 + movi v13.4s, #0x0 + movi v11.4s, #0x0 +KAI_ASM_LABEL(label_19) // Row tail: Sub block loop + ldr q2, [x26, #0x0] + ldr q1, [x26, #0x10] + movi v0.16b, #0xf0 + subs x20, x20, #0x1 + ldr q31, [x27, #0x0] + ldr q30, [x27, #0x10] + ldr q15, [x26, #0x20] + ldr q28, [x26, #0x30] + add x26, x26, #0x40 + ldr q27, [x27, #0x20] + ldr q26, [x27, #0x30] + shl v29.16b, v2.16b, #0x4 + shl v25.16b, v1.16b, #0x4 + ldr q23, [x27, #0x40] + ldr q22, [x27, #0x50] + and v2.16b, v2.16b, v0.16b + and v1.16b, v1.16b, v0.16b + ldr q21, [x27, #0x60] + ldr q20, [x27, #0x70] + shl v19.16b, v15.16b, #0x4 + shl v18.16b, v28.16b, #0x4 + KAI_ASM_INST(0x4e9da7ee) // smmla v14.4s, v31.16b, v29.16b + KAI_ASM_INST(0x4e99a7f8) // smmla v24.4s, v31.16b, v25.16b + and v15.16b, v15.16b, v0.16b + add x27, x27, #0x80 + KAI_ASM_INST(0x4e9da7cd) // smmla v13.4s, v30.16b, v29.16b + KAI_ASM_INST(0x4e99a7cb) // smmla v11.4s, v30.16b, v25.16b + and v28.16b, v28.16b, v0.16b + 
KAI_ASM_INST(0x4e93a76e) // smmla v14.4s, v27.16b, v19.16b + KAI_ASM_INST(0x4e92a778) // smmla v24.4s, v27.16b, v18.16b + KAI_ASM_INST(0x4e93a74d) // smmla v13.4s, v26.16b, v19.16b + KAI_ASM_INST(0x4e92a74b) // smmla v11.4s, v26.16b, v18.16b + KAI_ASM_INST(0x4e82a6ee) // smmla v14.4s, v23.16b, v2.16b + KAI_ASM_INST(0x4e81a6f8) // smmla v24.4s, v23.16b, v1.16b + KAI_ASM_INST(0x4e82a6cd) // smmla v13.4s, v22.16b, v2.16b + KAI_ASM_INST(0x4e81a6cb) // smmla v11.4s, v22.16b, v1.16b + KAI_ASM_INST(0x4e8fa6ae) // smmla v14.4s, v21.16b, v15.16b + KAI_ASM_INST(0x4e9ca6b8) // smmla v24.4s, v21.16b, v28.16b + KAI_ASM_INST(0x4e8fa68d) // smmla v13.4s, v20.16b, v15.16b + KAI_ASM_INST(0x4e9ca68b) // smmla v11.4s, v20.16b, v28.16b + bgt label_19 + ldr d26, [x26, #0x0] + ldr q25, [SP, #0x0] + uzp1 v23.2d, v14.2d, v24.2d + uzp2 v22.2d, v14.2d, v24.2d + ldr q21, [SP, #0x10] + ldr q20, [SP, #0x20] + uzp1 v19.2d, v13.2d, v11.2d + uzp2 v18.2d, v13.2d, v11.2d + ldr q17, [SP, #0x30] + add x26, x26, #0x8 + shll v16.4s, v26.4h, #0x10 + scvtf v23.4s, v23.4s, #0x4 + scvtf v22.4s, v22.4s, #0x4 + scvtf v19.4s, v19.4s, #0x4 + scvtf v18.4s, v18.4s, #0x4 + fmla v25.4s, v23.4s, v16.4s + fmla v21.4s, v22.4s, v16.4s + fmla v20.4s, v19.4s, v16.4s + fmla v17.4s, v18.4s, v16.4s + str q25, [SP, #0x0] + str q21, [SP, #0x10] + str q20, [SP, #0x20] + str q17, [SP, #0x30] + subs x21, x21, #0x1 + bgt label_18 + ld1 { v21.4s }, [x27] + ldr q31, [SP, #0x0] + add x27, x27, #0x10 + add x20, x12, #0x4 + ldr q30, [SP, #0x10] + ldr q29, [SP, #0x20] + cmp x25, #0x4 + ldr q28, [SP, #0x30] + ldr q20, [x26, #0x0] + ldr q19, [x27, #0x0] + ldr q18, [x26, #0x10] + scvtf v21.4s, v21.4s + add x26, x26, #0x20 + ld1r { v17.4s }, [x12] + ld1r { v16.4s }, [x20] + fmla v31.4s, v20.4s, v21.s[0] + fmla v30.4s, v20.4s, v21.s[1] + fmla v29.4s, v20.4s, v21.s[2] + fmla v28.4s, v20.4s, v21.s[3] + fmul v31.4s, v31.4s, v19.s[0] + fmul v30.4s, v30.4s, v19.s[1] + fadd v31.4s, v31.4s, v18.4s + fmul v29.4s, v29.4s, v19.s[2] + fmul v28.4s, v28.4s, v19.s[3] + fadd v30.4s, v30.4s, v18.4s + fmax v31.4s, v31.4s, v17.4s + fadd v29.4s, v29.4s, v18.4s + fadd v28.4s, v28.4s, v18.4s + fmax v30.4s, v30.4s, v17.4s + fmin v31.4s, v31.4s, v16.4s + fmax v29.4s, v29.4s, v17.4s + fmax v28.4s, v28.4s, v17.4s + fmin v30.4s, v30.4s, v16.4s + KAI_ASM_INST(0x0ea16bf3) // bfcvtn v19.4h, v31.4s + fmin v29.4s, v29.4s, v16.4s + fmin v28.4s, v28.4s, v16.4s + KAI_ASM_INST(0x0ea16bd2) // bfcvtn v18.4h, v30.4s + KAI_ASM_INST(0x0ea16bb1) // bfcvtn v17.4h, v29.4s + KAI_ASM_INST(0x0ea16b90) // bfcvtn v16.4h, v28.4s + blt label_21 + mov x20, x14 + cmp x15, #0x1 + str d19, [x20, #0x0] + add x20, x20, x13 + ble label_24 + cmp x15, #0x2 + str d18, [x20, #0x0] + add x20, x20, x13 + ble label_24 + cmp x15, #0x3 + str d17, [x20, #0x0] + add x20, x20, x13 + ble label_24 + str d16, [x20, #0x0] + b label_24 +KAI_ASM_LABEL(label_21) // Row tail: Partial output + mov x23, x14 + cmp x15, #0x1 + add x22, x23, x13 + csel x22, x22, x23, GT + cmp x15, #0x2 + add x21, x23, x13, LSL #1 + csel x21, x21, x22, GT + cmp x15, #0x3 + add x20, x21, x13 + csel x20, x20, x21, GT + tbz x25, #1, label_22 + st1 { v16.s }[0], [x20], #0x4 + st1 { v17.s }[0], [x21], #0x4 + st1 { v18.s }[0], [x22], #0x4 + st1 { v19.s }[0], [x23], #0x4 + tbz x25, #0, label_23 + st1 { v16.h }[2], [x20] + st1 { v17.h }[2], [x21] + st1 { v18.h }[2], [x22] + st1 { v19.h }[2], [x23] + b label_23 +KAI_ASM_LABEL(label_22) // Row tail: Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] + st1 { v17.h }[0], [x21] + st1 { v18.h }[0], [x22] + st1 { v19.h 
}[0], [x23]
+KAI_ASM_LABEL(label_23) // Row tail: Output block 0: Done
+KAI_ASM_LABEL(label_24) // Row tail: Output stage exit
+ subs x25, x25, #0x4
+ add x14, x14, #0x8
+ bgt label_17
+ subs x15, x15, #0x4
+ add x8, x8, x5
+ mov x14, x24
+ bgt label_16
+KAI_ASM_LABEL(label_25) // Row tail: Row loop skip
+ add SP, SP, #0x100
+ ldp x22, x23, [sp, 16]
+ ldp x24, x25, [sp, 32]
+ ldp x26, x27, [sp, 48]
+ ldr x28, [sp, 64]
+ ldp d10, d11, [sp, 72]
+ ldp d12, d13, [sp, 88]
+ ldp d14, d15, [sp, 104]
+ ldp d8, d9, [sp, 120]
+ ldp x20, x21, [sp], 144
+ ret
+ KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm)
+
+ KAI_ASM_END
diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h
new file mode 100644
index 00000000..fdc7f6e7
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h
@@ -0,0 +1,53 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernel variants of the same type share the same interface
+// In this case, the micro-kernel type is: matmul_clamp_bf16_qai8dxp_qsi4c32p
+
+/// Micro-kernel helper functions ("get" methods)
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_rhs_packed_offset_func_t)(
+ size_t n_idx, size_t k, size_t bl);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_dst_offset_func_t)(
+ size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_dst_size_func_t)(size_t m, size_t n);
+
+/// Micro-kernel core function ("run" method)
+typedef void (*kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_run_matmul_func_t)(
+ size_t m, size_t n, size_t k, size_t bl, const void* lhs_p, const void* rhs_p, void* dst, size_t dst_stride_row,
+ size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/// Micro-kernel interface
+struct kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_ukernel {
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_m_step_func_t get_m_step;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_n_step_func_t get_n_step;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_mr_func_t get_mr;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_nr_func_t get_nr;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_kr_func_t get_kr;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_sr_func_t get_sr;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_dst_offset_func_t get_dst_offset;
+ kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_get_dst_size_func_t
get_dst_size; + kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 9eba8b95..19c32d48 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -370,6 +370,14 @@ template Buffer matmul_nt_t_quantized( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, + size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, + size_t rhs_quant_width, // + const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); + template Buffer matmul_nt_t_quantized( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, diff --git a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp new file mode 100644 index 00000000..31acc6db --- /dev/null +++ b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp @@ -0,0 +1,448 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h" +#include "test/common/bfloat16.hpp" +#include "test/common/buffer.hpp" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/data_format.hpp" +#include "test/common/int4.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/test_suite.hpp" +#include "test/reference/cast.hpp" +#include "test/reference/clamp.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/pad.hpp" +#include "test/reference/quantize.hpp" +#include "test/reference/transpose.hpp" + +namespace kai::test { + +static const std::array, 2> + variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p = {{ + {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod), + "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod}, + {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm), + "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", cpu_has_i8mm}, + }}; + +using MatMulTestParams_withBL = std::tuple; + +class MatMulTest_bf16_qai8dxp_qsi4c32p : public ::testing::TestWithParam {}; + +TEST_P(MatMulTest_bf16_qai8dxp_qsi4c32p, EndToEnd_RHS_NxK) { + const 
auto& [variant_index, matmul_shape, bl, portion] = GetParam();
+ const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_index);
+
+ if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
+ GTEST_SKIP() << "CPU features are not supported by current CPU";
+ }
+
+ const std::uint32_t seed = 0;
+
+ const size_t M = matmul_shape.m;
+ const size_t N = matmul_shape.n;
+ const size_t K = matmul_shape.k;
+
+ const auto mr = ukernel_variant.interface.get_mr();
+ const auto nr = ukernel_variant.interface.get_nr();
+ const auto kr = ukernel_variant.interface.get_kr();
+ const auto sr = ukernel_variant.interface.get_sr();
+
+ auto m_step = ukernel_variant.interface.get_m_step();
+ ASSERT_TRUE(m_step % mr == 0);
+
+ auto n_step = ukernel_variant.interface.get_n_step();
+ ASSERT_TRUE(n_step % nr == 0);
+
+ const auto rect = portion.compute_portion(M, N, m_step, n_step);
+ if (rect.height() == 0 || rect.width() == 0) {
+ GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
+ }
+
+ // Generates input data.
+ const auto ref_lhs_bf16 = fill_random(M * K, seed + 0);
+ const auto ref_rhs = fill_random(N * K, seed + 1);
+ const auto ref_biases = fill_random(N, seed + 2);
+ kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
+
+ // For the reference implementation, cast the BF16 input to FP32 and the FP32 output back to BF16, because the
+ // matmul implementation works with FP32 accumulation and casts the result to BF16.
+ const auto ref_lhs = cast(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits);
+
+ // Runs the reference implementation.
+ // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
+ // * Quantizes the RHS matrix using 4-bit symmetric quantization.
+ // * Performs GEMM.
+ const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
+ quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K);
+ const auto [ref_rhs_qsi4, ref_rhs_scales] =
+ quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl);
+
+ const auto ref_dst_no_clamp =
+ matmul_nt_t_quantized(
+ M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
+ ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, 1);
+
+ // Clamps the reference output.
+ const auto clamp_ratio = 0.8F;
+ const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_no_clamp.data(), M * N, clamp_ratio);
+ const auto ref_dst_float = clamp(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
+
+ // Cast the reference output to BF16
+ auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits);
+
+ // Runs the LHS packing micro-kernel.
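+ // * Computes one dynamic scale and zero point per LHS row and quantizes the row to 8-bit asymmetric values.
+ // * Packs the quantized rows into mr-row blocks in the layout the matmul micro-kernel reads.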
+ const auto lhs_start_row = rect.start_row(); + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); + Buffer imp_packed_lhs = Buffer(imp_packed_lhs_size); + + auto lhs_stride = K * sizeof(uint16_t); + + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); + auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); + + ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); + + kai_run_lhs_quant_pack_qai8dxp_bf16_neon( + rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, + reinterpret_cast(imp_packed_lhs.data()) + lhs_packed_offset); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. + const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + const auto ref_rhs_qsu4_padded = pad_row( + ref_rhs_qsu4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); + + const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); + + const auto imp_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); + Buffer imp_packed_rhs(imp_packed_rhs_size); + + const auto rhs_start_row = rect.start_col(); + auto rhs_packed_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rhs_start_row, K, nr, kr, sr, bl, scale_dt); + auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); + ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); + + auto rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rhs_start_row, ref_rhs_qsu4_stride); + size_t bias_offset = rhs_start_row * sizeof(float); + size_t scale_offset = rhs_start_row * ref_rhs_scales_stride; + + kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{}; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + params.scale_dt = kai_datatype::kai_dt_bf16; + + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + 1, rect.width() /* n */, K, nr, kr, sr, bl, + reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset), ref_rhs_qsu4_stride, + reinterpret_cast(ref_biases.data() + bias_offset), + reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, + imp_packed_rhs.data() + rhs_packed_offset, 0, ¶ms); + + const auto dst_stride_row = N * sizeof(uint16_t); + const auto dst_stride_col = sizeof(uint16_t); + const auto dst_offset = + ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); + const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; + ASSERT_EQ(dst_offset, ref_dst_offset); + + // Runs the GEMM micro-kernel. 
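+ // * Consumes the packed LHS and RHS produced above and writes clamped BF16 results for the tested portion only.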
+ const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + Buffer imp_dst = Buffer(imp_dst_size); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, bl, reinterpret_cast(imp_packed_lhs.data()) + lhs_matmul_offset, + reinterpret_cast(imp_packed_rhs.data()) + rhs_matmul_offset, + reinterpret_cast(imp_dst.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, clamp_max); + + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. + DefaultMismatchHandler handler(0, 0.02, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::BF16); + const auto success = + compare(reinterpret_cast(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); + + if (kr / sr == 8) { + // Test that vectorized packing kernel for nrx8 gives same output as scalar + const auto imp_packed_rhs_size_neon = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); + ASSERT_EQ(imp_packed_rhs_size_neon, imp_packed_rhs_size); + + Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon); + + auto rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + rhs_start_row, K, nr, kr, sr, bl, scale_dt); + ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); + + auto rhs_offset_neon = + kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(rhs_start_row, ref_rhs_qsu4_stride); + + kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( + 1, rect.width() /* n */, K, nr, kr, sr, bl, + reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset_neon), ref_rhs_qsu4_stride, + reinterpret_cast(ref_biases.data() + bias_offset), + reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, + imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, + imp_packed_rhs_neon.data() + rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col, + clamp_min, clamp_max); + + const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); + } else if (kr / sr == 4) { + // Test that vectorized packing kernel for nrx4 gives same output as scalar + const auto imp_packed_rhs_size_neon = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); + ASSERT_EQ(imp_packed_rhs_size_neon, imp_packed_rhs_size); + + Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon); + + auto rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( + rhs_start_row, K, nr, kr, sr, bl, scale_dt); + ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); + + auto rhs_offset_neon = + kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(rhs_start_row, ref_rhs_qsu4_stride); + + kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( + 1, rect.width() /* n */, K, nr, kr, sr, bl, + reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset_neon), ref_rhs_qsu4_stride, + reinterpret_cast(ref_biases.data() + bias_offset), + reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, + imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, + imp_packed_rhs_neon.data() + 
rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col,
+ clamp_min, clamp_max);
+
+ const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
+ ASSERT_TRUE(success);
+ }
+}
+
+TEST_P(MatMulTest_bf16_qai8dxp_qsi4c32p, EndToEnd_RHS_KxN) {
+ const auto& [variant_index, matmul_shape, bl, portion] = GetParam();
+ const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_index);
+
+ if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
+ GTEST_SKIP() << "CPU features are not supported by current CPU";
+ }
+
+ const std::uint32_t seed = 0;
+
+ const size_t M = matmul_shape.m;
+ const size_t N = matmul_shape.n;
+ const size_t K = matmul_shape.k;
+
+ const auto mr = ukernel_variant.interface.get_mr();
+ const auto nr = ukernel_variant.interface.get_nr();
+ const auto kr = ukernel_variant.interface.get_kr();
+ const auto sr = ukernel_variant.interface.get_sr();
+
+ auto m_step = ukernel_variant.interface.get_m_step();
+ ASSERT_TRUE(m_step % mr == 0);
+
+ auto n_step = ukernel_variant.interface.get_n_step();
+ ASSERT_TRUE(n_step % nr == 0);
+
+ const auto rect = portion.compute_portion(M, N, m_step, n_step);
+ if (rect.height() == 0 || rect.width() == 0) {
+ GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
+ }
+
+ // Generates input data.
+ const auto ref_lhs_bf16 = fill_random(M * K, seed + 0);
+ const auto ref_rhs = fill_random(N * K, seed + 1);
+ const auto ref_biases = fill_random(N, seed + 2);
+ kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
+
+ // For the reference implementation, cast the BF16 input to FP32 and the FP32 output back to BF16, because the
+ // matmul implementation works with FP32 accumulation and casts the result to BF16.
+ const auto ref_lhs = cast(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits);
+
+ // Transposed(nxk) RHS dimensions
+ const size_t ref_rhs_qsi4_nxk_stride = K;
+
+ // Non-Transposed(kxn) RHS dimensions
+ const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
+ const size_t ref_rhs_qsi4_kxn_size = K * ref_rhs_qsi4_kxn_stride;
+ const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(ref_rhs_qsi4_kxn_size, 2);
+
+ // Runs the reference implementation.
+ // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
+ // * Quantizes the RHS matrix using 4-bit symmetric quantization.
+ // * Performs GEMM.
+ const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
+ quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K);
+ const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] =
+ quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl);
+
+ auto ref_rhs_qsi4 = transpose_with_padding(
+ ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
+ ref_rhs_qsi4_kxn_size_bytes);
+
+ const auto ref_dst_clamp_f32 =
+ matmul_clamp_nt_nt(
+ M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
+ ref_rhs_scales.data(), nullptr, bl, ref_biases.data(), std::numeric_limits<float>::lowest(),
+ std::numeric_limits<float>::max());
+
+ // Clamps the reference output.
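+ // find_clamp_range derives the clamp bounds from the observed output range (see clamp_ratio below) so that
+ // clamping is genuinely exercised without flattening most outputs.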
+ const auto clamp_ratio = 0.8F; + const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_clamp_f32.data(), M * N, clamp_ratio); + const auto ref_dst_float = clamp(ref_dst_clamp_f32.data(), M * N, clamp_min, clamp_max); + + // Cast the reference output to BF16 + auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits); + + // Runs the LHS packing micro-kernel. + const auto lhs_start_row = rect.start_row(); + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); + Buffer imp_packed_lhs = Buffer(imp_packed_lhs_size); + + auto lhs_stride = K * sizeof(uint16_t); + + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); + auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); + + ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); + + kai_run_lhs_quant_pack_qai8dxp_bf16_neon( + rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, + reinterpret_cast(imp_packed_lhs.data()) + lhs_packed_offset); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. + const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), ref_rhs_qsi4_kxn_size); + const auto ref_rhs_qsu4_padded = pad_row( + ref_rhs_qsu4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); + const size_t ref_rhs_qsu4_stride = round_up_division(N, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); + + const auto imp_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); + Buffer imp_packed_rhs(imp_packed_rhs_size); + + const auto rhs_start_row = rect.start_col(); + auto rhs_packed_offset = + kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rhs_start_row, K, nr, kr, sr, bl, scale_dt); + auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); + ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); + + auto rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rhs_start_row, ref_rhs_qsu4_stride); + size_t bias_offset = rhs_start_row * sizeof(float); + size_t scale_offset = rhs_start_row * ref_rhs_scales_stride; + + kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{}; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + params.scale_dt = kai_datatype::kai_dt_bf16; + + kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + 1, rect.width() /* n */, K, nr, kr, sr, bl, + reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset), ref_rhs_qsu4_stride, + reinterpret_cast(ref_biases.data() + bias_offset), + reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, + imp_packed_rhs.data() + rhs_packed_offset, 0, ¶ms); + + const auto dst_stride_row = N * sizeof(uint16_t); + const auto dst_stride_col = sizeof(uint16_t); + const auto dst_offset = + ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); + const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; + ASSERT_EQ(dst_offset, ref_dst_offset); + + // Runs the GEMM micro-kernel. 
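+ // * Same matmul entry point as the NxK test; only the RHS packing route differs.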
+ const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + Buffer imp_dst = Buffer(imp_dst_size); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, bl, reinterpret_cast(imp_packed_lhs.data()) + lhs_matmul_offset, + reinterpret_cast(imp_packed_rhs.data()) + rhs_matmul_offset, imp_dst.data() + dst_offset, + dst_stride_row, dst_stride_col, clamp_min, clamp_max); + + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. + DefaultMismatchHandler handler(0, 0.02, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::BF16); + const auto success = + compare(reinterpret_cast(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest_bf16_qai8dxp_qsi4c32p, + testing::Combine( + testing::Range(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.size()), + testing::Values( + MatMulShape{16, 32, 64}, // + MatMulShape{8, 32, 128}, // + MatMulShape{17, 25, 64}, // + MatMulShape{15, 31, 128}, // + MatMulShape{1, 25, 64}), + testing::Values(32, 64), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, 0, 1, 0.25f), // Leftmost portion. + MatrixPortion(0, 0.75f, 1, 1), // Rightmost portion. + MatrixPortion(0, 0.5f, 1, 0.8f) // Somewhere Middle + )), + [](const auto& info) { + const auto variant_idx = std::get<0>(info.param); + const std::string name{variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_idx).name}; + const auto shape = std::get(info.param); + const auto bl = std::get<2>(info.param); + const auto portion = std::get<3>(info.param); + + std::ostringstream sstream; + sstream << name << "__"; + PrintTo(shape, &sstream); + sstream << "__BL_" << bl << "__"; + PrintTo(portion, &sstream); + + return sstream.str(); + }); + +} // namespace kai::test -- GitLab From 906bf80d55293cb54a77b3870be637ee5e6108c4 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 21 Jul 2025 11:19:22 +0100 Subject: [PATCH 2/5] combine unit test with similar f32 one Signed-off-by: Evie Wright --- CMakeLists.txt | 6 +- test/reference/matmul.cpp | 8 - ...atmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp | 448 ------------------ ...=> matmul_clamp_qai8dxp_qsi4c32p_test.cpp} | 191 +++++++- 4 files changed, 191 insertions(+), 462 deletions(-) delete mode 100644 test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp rename test/tests/{matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp => matmul_clamp_qai8dxp_qsi4c32p_test.cpp} (67%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f1209371..86e5cf50 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -462,10 +462,9 @@ if(KLEIDIAI_BUILD_TESTS) add_executable(kleidiai_test test/tests/bfloat16_test.cpp test/tests/float16_test.cpp - test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp - test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp + test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp ) else() add_executable(kleidiai_test @@ -474,11 +473,9 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/float16_test.cpp test/tests/imatmul_test.cpp test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp - test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp - 
test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp @@ -487,6 +484,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_clamp_f16_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp + test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_test.cpp ) diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 19c32d48..9eba8b95 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -370,14 +370,6 @@ template Buffer matmul_nt_t_quantized( - size_t m, size_t n, size_t k, // - const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, - size_t lhs_quant_width, // - const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, - size_t rhs_quant_width, // - const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); - template Buffer matmul_nt_t_quantized( size_t m, size_t n, size_t k, // const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, diff --git a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp deleted file mode 100644 index 31acc6db..00000000 --- a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4c32p_test.cpp +++ /dev/null @@ -1,448 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h" -#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h" -#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h" -#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h" -#include "test/common/bfloat16.hpp" -#include "test/common/buffer.hpp" -#include "test/common/compare.hpp" -#include "test/common/cpu_info.hpp" -#include "test/common/data_format.hpp" -#include "test/common/int4.hpp" -#include "test/common/matmul_test_common.hpp" -#include "test/common/matrix_portion.hpp" -#include "test/common/memory.hpp" -#include "test/common/round.hpp" -#include "test/common/test_suite.hpp" -#include "test/reference/cast.hpp" -#include "test/reference/clamp.hpp" -#include "test/reference/fill.hpp" -#include "test/reference/matmul.hpp" -#include "test/reference/pack.hpp" -#include "test/reference/pad.hpp" -#include "test/reference/quantize.hpp" -#include "test/reference/transpose.hpp" - -namespace kai::test { - -static const std::array, 2> - variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p = {{ - {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod), - "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", 
cpu_has_dotprod},
- {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm),
- "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", cpu_has_i8mm},
- }};
-
-using MatMulTestParams_withBL = std::tuple;
-
-class MatMulTest_bf16_qai8dxp_qsi4c32p : public ::testing::TestWithParam {};
-
-TEST_P(MatMulTest_bf16_qai8dxp_qsi4c32p, EndToEnd_RHS_NxK) {
- const auto& [variant_index, matmul_shape, bl, portion] = GetParam();
- const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_index);
-
- if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
- GTEST_SKIP() << "CPU features are not supported by current CPU";
- }
-
- const std::uint32_t seed = 0;
-
- const size_t M = matmul_shape.m;
- const size_t N = matmul_shape.n;
- const size_t K = matmul_shape.k;
-
- const auto mr = ukernel_variant.interface.get_mr();
- const auto nr = ukernel_variant.interface.get_nr();
- const auto kr = ukernel_variant.interface.get_kr();
- const auto sr = ukernel_variant.interface.get_sr();
-
- auto m_step = ukernel_variant.interface.get_m_step();
- ASSERT_TRUE(m_step % mr == 0);
-
- auto n_step = ukernel_variant.interface.get_n_step();
- ASSERT_TRUE(n_step % nr == 0);
-
- const auto rect = portion.compute_portion(M, N, m_step, n_step);
- if (rect.height() == 0 || rect.width() == 0) {
- GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
- }
-
- // Generates input data.
- const auto ref_lhs_bf16 = fill_random(M * K, seed + 0);
- const auto ref_rhs = fill_random(N * K, seed + 1);
- const auto ref_biases = fill_random(N, seed + 2);
- kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
-
- // For the reference implementation, cast the BF16 input to FP32 and the FP32 output back to BF16, because the
- // matmul implementation works with FP32 accumulation and casts the result to BF16.
- const auto ref_lhs = cast(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits);
-
- // Runs the reference implementation.
- // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
- // * Quantizes the RHS matrix using 4-bit symmetric quantization.
- // * Performs GEMM.
- const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
- quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K);
- const auto [ref_rhs_qsi4, ref_rhs_scales] =
- quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl);
-
- const auto ref_dst_no_clamp =
- matmul_nt_t_quantized(
- M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
- ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, 1);
-
- // Clamps the reference output.
- const auto clamp_ratio = 0.8F;
- const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_no_clamp.data(), M * N, clamp_ratio);
- const auto ref_dst_float = clamp(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
-
- // Cast the reference output to BF16
- auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits);
-
- // Runs the LHS packing micro-kernel.
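- // * Computes one dynamic scale and zero point per LHS row and quantizes the row to 8-bit asymmetric values.
- // * Packs the quantized rows into mr-row blocks in the layout the matmul micro-kernel reads.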
- const auto lhs_start_row = rect.start_row(); - const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); - Buffer imp_packed_lhs = Buffer(imp_packed_lhs_size); - - auto lhs_stride = K * sizeof(uint16_t); - - auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); - auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); - auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); - - ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); - - kai_run_lhs_quant_pack_qai8dxp_bf16_neon( - rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, - reinterpret_cast(imp_packed_lhs.data()) + lhs_packed_offset); - - // Runs the RHS packing micro-kernel. - // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. - // * Packs the RHS matrix. - const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); - const auto ref_rhs_qsu4_padded = pad_row( - ref_rhs_qsu4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); - - const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); - const size_t ref_rhs_scales_stride = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); - - const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); - Buffer imp_packed_rhs(imp_packed_rhs_size); - - const auto rhs_start_row = rect.start_col(); - auto rhs_packed_offset = - kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rhs_start_row, K, nr, kr, sr, bl, scale_dt); - auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); - ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); - - auto rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rhs_start_row, ref_rhs_qsu4_stride); - size_t bias_offset = rhs_start_row * sizeof(float); - size_t scale_offset = rhs_start_row * ref_rhs_scales_stride; - - kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{}; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - params.scale_dt = kai_datatype::kai_dt_bf16; - - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - 1, rect.width() /* n */, K, nr, kr, sr, bl, - reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset), ref_rhs_qsu4_stride, - reinterpret_cast(ref_biases.data() + bias_offset), - reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, - imp_packed_rhs.data() + rhs_packed_offset, 0, ¶ms); - - const auto dst_stride_row = N * sizeof(uint16_t); - const auto dst_stride_col = sizeof(uint16_t); - const auto dst_offset = - ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); - const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; - ASSERT_EQ(dst_offset, ref_dst_offset); - - // Runs the GEMM micro-kernel. 
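- // * Consumes the packed LHS and RHS produced above and writes clamped BF16 results for the tested portion only.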
- const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); - ASSERT_EQ(imp_dst_size, ref_dst.size()); - Buffer imp_dst = Buffer(imp_dst_size); - - ukernel_variant.interface.run_matmul( - rect.height(), rect.width(), K, bl, reinterpret_cast(imp_packed_lhs.data()) + lhs_matmul_offset, - reinterpret_cast(imp_packed_rhs.data()) + rhs_matmul_offset, - reinterpret_cast(imp_dst.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, clamp_max); - - // Compares the output of the micro-kernels against the output of the reference implementation for the portion - // tested. - DefaultMismatchHandler handler(0, 0.02, 0, 0.05); - DataFormat dst_format = DataFormat(DataType::BF16); - const auto success = - compare(reinterpret_cast(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler); - ASSERT_TRUE(success); - - if (kr / sr == 8) { - // Test that vectorized packing kernel for nrx8 gives same output as scalar - const auto imp_packed_rhs_size_neon = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); - ASSERT_EQ(imp_packed_rhs_size_neon, imp_packed_rhs_size); - - Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon); - - auto rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( - rhs_start_row, K, nr, kr, sr, bl, scale_dt); - ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); - - auto rhs_offset_neon = - kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(rhs_start_row, ref_rhs_qsu4_stride); - - kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( - 1, rect.width() /* n */, K, nr, kr, sr, bl, - reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset_neon), ref_rhs_qsu4_stride, - reinterpret_cast(ref_biases.data() + bias_offset), - reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, - imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); - - ukernel_variant.interface.run_matmul( - rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, - imp_packed_rhs_neon.data() + rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col, - clamp_min, clamp_max); - - const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); - ASSERT_TRUE(success); - } else if (kr / sr == 4) { - // Test that vectorized packing kernel for nrx4 gives same output as scalar - const auto imp_packed_rhs_size_neon = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); - ASSERT_EQ(imp_packed_rhs_size_neon, imp_packed_rhs_size); - - Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon); - - auto rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( - rhs_start_row, K, nr, kr, sr, bl, scale_dt); - ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); - - auto rhs_offset_neon = - kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(rhs_start_row, ref_rhs_qsu4_stride); - - kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( - 1, rect.width() /* n */, K, nr, kr, sr, bl, - reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset_neon), ref_rhs_qsu4_stride, - reinterpret_cast(ref_biases.data() + bias_offset), - reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, - imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); - - ukernel_variant.interface.run_matmul( - rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, - imp_packed_rhs_neon.data() + 
rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col,
- clamp_min, clamp_max);
-
- const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
- ASSERT_TRUE(success);
- }
-}
-
-TEST_P(MatMulTest_bf16_qai8dxp_qsi4c32p, EndToEnd_RHS_KxN) {
- const auto& [variant_index, matmul_shape, bl, portion] = GetParam();
- const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_index);
-
- if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
- GTEST_SKIP() << "CPU features are not supported by current CPU";
- }
-
- const std::uint32_t seed = 0;
-
- const size_t M = matmul_shape.m;
- const size_t N = matmul_shape.n;
- const size_t K = matmul_shape.k;
-
- const auto mr = ukernel_variant.interface.get_mr();
- const auto nr = ukernel_variant.interface.get_nr();
- const auto kr = ukernel_variant.interface.get_kr();
- const auto sr = ukernel_variant.interface.get_sr();
-
- auto m_step = ukernel_variant.interface.get_m_step();
- ASSERT_TRUE(m_step % mr == 0);
-
- auto n_step = ukernel_variant.interface.get_n_step();
- ASSERT_TRUE(n_step % nr == 0);
-
- const auto rect = portion.compute_portion(M, N, m_step, n_step);
- if (rect.height() == 0 || rect.width() == 0) {
- GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
- }
-
- // Generates input data.
- const auto ref_lhs_bf16 = fill_random(M * K, seed + 0);
- const auto ref_rhs = fill_random(N * K, seed + 1);
- const auto ref_biases = fill_random(N, seed + 2);
- kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
-
- // For the reference implementation, cast the BF16 input to FP32 and the FP32 output back to BF16, because the
- // matmul implementation works with FP32 accumulation and casts the result to BF16.
- const auto ref_lhs = cast(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits);
-
- // Transposed(nxk) RHS dimensions
- const size_t ref_rhs_qsi4_nxk_stride = K;
-
- // Non-Transposed(kxn) RHS dimensions
- const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
- const size_t ref_rhs_qsi4_kxn_size = K * ref_rhs_qsi4_kxn_stride;
- const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(ref_rhs_qsi4_kxn_size, 2);
-
- // Runs the reference implementation.
- // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
- // * Quantizes the RHS matrix using 4-bit symmetric quantization.
- // * Performs GEMM.
- const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
- quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K);
- const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] =
- quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, bl);
-
- auto ref_rhs_qsi4 = transpose_with_padding(
- ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
- ref_rhs_qsi4_kxn_size_bytes);
-
- const auto ref_dst_clamp_f32 =
- matmul_clamp_nt_nt(
- M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
- ref_rhs_scales.data(), nullptr, bl, ref_biases.data(), std::numeric_limits<float>::lowest(),
- std::numeric_limits<float>::max());
-
- // Clamps the reference output.
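- // find_clamp_range derives the clamp bounds from the observed output range (see clamp_ratio below) so that
- // clamping is genuinely exercised without flattening most outputs.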
- const auto clamp_ratio = 0.8F; - const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_clamp_f32.data(), M * N, clamp_ratio); - const auto ref_dst_float = clamp(ref_dst_clamp_f32.data(), M * N, clamp_min, clamp_max); - - // Cast the reference output to BF16 - auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits); - - // Runs the LHS packing micro-kernel. - const auto lhs_start_row = rect.start_row(); - const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); - Buffer imp_packed_lhs = Buffer(imp_packed_lhs_size); - - auto lhs_stride = K * sizeof(uint16_t); - - auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); - auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); - auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); - - ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); - - kai_run_lhs_quant_pack_qai8dxp_bf16_neon( - rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, - reinterpret_cast(imp_packed_lhs.data()) + lhs_packed_offset); - - // Runs the RHS packing micro-kernel. - // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. - // * Packs the RHS matrix. - const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), ref_rhs_qsi4_kxn_size); - const auto ref_rhs_qsu4_padded = pad_row( - ref_rhs_qsu4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); - const size_t ref_rhs_qsu4_stride = round_up_division(N, 2); - const size_t ref_rhs_scales_stride = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); - - const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); - Buffer imp_packed_rhs(imp_packed_rhs_size); - - const auto rhs_start_row = rect.start_col(); - auto rhs_packed_offset = - kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rhs_start_row, K, nr, kr, sr, bl, scale_dt); - auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); - ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); - - auto rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rhs_start_row, ref_rhs_qsu4_stride); - size_t bias_offset = rhs_start_row * sizeof(float); - size_t scale_offset = rhs_start_row * ref_rhs_scales_stride; - - kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{}; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - params.scale_dt = kai_datatype::kai_dt_bf16; - - kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( - 1, rect.width() /* n */, K, nr, kr, sr, bl, - reinterpret_cast(ref_rhs_qsu4_padded.data() + rhs_offset), ref_rhs_qsu4_stride, - reinterpret_cast(ref_biases.data() + bias_offset), - reinterpret_cast(ref_rhs_scales.data() + scale_offset), ref_rhs_scales_stride, - imp_packed_rhs.data() + rhs_packed_offset, 0, ¶ms); - - const auto dst_stride_row = N * sizeof(uint16_t); - const auto dst_stride_col = sizeof(uint16_t); - const auto dst_offset = - ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); - const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; - ASSERT_EQ(dst_offset, ref_dst_offset); - - // Runs the GEMM micro-kernel. 
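- // * Same matmul entry point as the NxK test; only the RHS packing route differs.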
- const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); - ASSERT_EQ(imp_dst_size, ref_dst.size()); - Buffer imp_dst = Buffer(imp_dst_size); - - ukernel_variant.interface.run_matmul( - rect.height(), rect.width(), K, bl, reinterpret_cast(imp_packed_lhs.data()) + lhs_matmul_offset, - reinterpret_cast(imp_packed_rhs.data()) + rhs_matmul_offset, imp_dst.data() + dst_offset, - dst_stride_row, dst_stride_col, clamp_min, clamp_max); - - // Compares the output of the micro-kernels against the output of the reference implementation for the portion - // tested. - DefaultMismatchHandler handler(0, 0.02, 0, 0.05); - DataFormat dst_format = DataFormat(DataType::BF16); - const auto success = - compare(reinterpret_cast(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler); - ASSERT_TRUE(success); -} - -INSTANTIATE_TEST_SUITE_P( - MatMul, MatMulTest_bf16_qai8dxp_qsi4c32p, - testing::Combine( - testing::Range(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.size()), - testing::Values( - MatMulShape{16, 32, 64}, // - MatMulShape{8, 32, 128}, // - MatMulShape{17, 25, 64}, // - MatMulShape{15, 31, 128}, // - MatMulShape{1, 25, 64}), - testing::Values(32, 64), - testing::Values( - MatrixPortion(0, 0, 1, 1), // Full matrix. - MatrixPortion(0, 0, 1, 0.25f), // Leftmost portion. - MatrixPortion(0, 0.75f, 1, 1), // Rightmost portion. - MatrixPortion(0, 0.5f, 1, 0.8f) // Somewhere Middle - )), - [](const auto& info) { - const auto variant_idx = std::get<0>(info.param); - const std::string name{variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_idx).name}; - const auto shape = std::get(info.param); - const auto bl = std::get<2>(info.param); - const auto portion = std::get<3>(info.param); - - std::ostringstream sstream; - sstream << name << "__"; - PrintTo(shape, &sstream); - sstream << "__BL_" << bl << "__"; - PrintTo(portion, &sstream); - - return sstream.str(); - }); - -} // namespace kai::test diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp similarity index 67% rename from test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp rename to test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp index f753b0b8..ab029abb 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp @@ -18,6 +18,9 @@ #include #include "kai/kai_common.h" +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" @@ -30,6 +33,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" @@ -82,6 +86,14 @@ static const std::array, 2> + variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p = {{ + {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod), + "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod}, + {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm), + "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", cpu_has_i8mm}, + }}; + // Executes the scalar RHS packing micro-kernel. static inline std::tuple pack_rhs_qsi4c32pscalebf16( size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, const Buffer& rhs_values_qsi4, const Buffer& biases, @@ -206,6 +218,8 @@ using MatMulTestParams_withBL_withRHSPackType = std::tuple {}; +class MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p + : public ::testing::TestWithParam {}; TEST_P(MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd) { auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type] = GetParam(); @@ -278,8 +292,8 @@ TEST_P(MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd) { const auto lhs_stride = K * sizeof(float); - auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); - auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); @@ -339,6 +353,145 @@ TEST_P(MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd) { } } +TEST_P(MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p, EndToEnd) { + auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type] = GetParam(); + auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_index); + + if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { + GTEST_SKIP() << "Unsupported CPU feature"; + } + + const uint32_t seed = 0; + + size_t M = matmul_shape.m; + size_t N = matmul_shape.n; + size_t K = matmul_shape.k; + + auto mr = ukernel_variant.interface.get_mr(); + auto nr = ukernel_variant.interface.get_nr(); + auto kr = ukernel_variant.interface.get_kr(); + auto sr = ukernel_variant.interface.get_sr(); + + auto m_step = ukernel_variant.interface.get_m_step(); + ASSERT_TRUE(m_step % mr == 0); + + auto n_step = ukernel_variant.interface.get_n_step(); + ASSERT_TRUE(n_step % nr == 0); + + auto rect = portion.compute_portion(M, N, m_step, n_step); + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; + } + + // Generates input data. 
+ const auto ref_lhs_bf16 = fill_random(M * K, seed + 0);
+ const auto ref_rhs = fill_random(N * K, seed + 1);
+ const auto ref_biases = fill_random(N, seed + 2);
+
+ // For the reference implementation, cast the BF16 input to FP32 and the FP32 output back to BF16, because the
+ // matmul implementation works with FP32 accumulation and casts the result to BF16.
+ const auto ref_lhs = cast(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits);
+
+ // Runs the reference implementation.
+ // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
+ // * Quantizes the RHS matrix using 4-bit symmetric quantization.
+ // * Performs GEMM.
+ auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
+ quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K);
+ auto [ref_rhs_values_qsi4, ref_rhs_scales] =
+ quantize_rhs_qsi4c32p(N, K, bl, ref_rhs, rhs_pack_type == RhsPackType::NxK);
+
+ Buffer ref_dst_noclamp;
+ if (rhs_pack_type == RhsPackType::NxK) {
+ ref_dst_noclamp =
+ matmul_nt_t_quantized(
+ M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
+ ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr,
+ 1);
+ } else {
+ ref_dst_noclamp =
+ matmul_nt_nt_quantized(
+ M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
+ ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr,
+ 1);
+ }
+
+ // Clamps the reference output.
+ const auto clamp_ratio = 0.8F;
+ const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_noclamp.data(), M * N, clamp_ratio);
+ auto ref_dst_float = clamp(ref_dst_noclamp.data(), M * N, clamp_min, clamp_max);
+
+ // Cast the reference output to BF16
+ auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits);
+
+ // Runs the LHS packing micro-kernel.
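+ // * The BF16 variant of the packing kernel quantizes and packs straight from the BF16 source, so the LHS
+ // stride below is expressed in uint16_t elements.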
+ const auto lhs_start_row = rect.start_row(); + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); + Buffer imp_packed_lhs(imp_packed_lhs_size); + + const auto lhs_stride = K * sizeof(uint16_t); + + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); + auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); + ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); + + kai_run_lhs_quant_pack_qai8dxp_bf16_neon( + rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, + reinterpret_cast(imp_packed_lhs.data()) + lhs_packed_offset); + + const auto rhs_start_row = rect.start_col(); + size_t bias_offset = rhs_start_row * sizeof(float); + + auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qsi4c32pscalebf16( + N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type, + rhs_start_row, rect.width()); + + auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); + ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); + + const auto dst_stride_row = N * sizeof(uint16_t); + const auto dst_stride_col = sizeof(uint16_t); + const auto dst_offset = + ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); + const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; + ASSERT_EQ(dst_offset, ref_dst_offset); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + Buffer imp_dst(imp_dst_size); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, bl, reinterpret_cast(imp_packed_lhs.data()) + lhs_matmul_offset, + reinterpret_cast(imp_packed_rhs.data()) + rhs_matmul_offset, imp_dst.data() + dst_offset, + dst_stride_row, dst_stride_col, clamp_min, clamp_max); + + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. 
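+ // BF16 carries far fewer mantissa bits than FP32, so the handler below allows a small relative error and a
+ // bounded fraction of mismatching elements.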
+    // Compares the output of the micro-kernels against the output of the reference implementation for the
+    // portion tested. The handler arguments are the tolerances: absolute error, relative error, and the allowed
+    // absolute number and ratio of mismatches.
+    DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
+    DataFormat dst_format = DataFormat(DataType::BF16);
+    const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
+    ASSERT_TRUE(success);
+
+    // Tests the vectorized RHS packing function, if the packing parameters allow it.
+    if (rhs_pack_type == RhsPackType::NxK && (kr / sr == 8 || kr / sr == 4)) {
+        const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon(
+            N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type,
+            rhs_start_row, rect.width());
+        ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset);
+
+        ukernel_variant.interface.run_matmul(
+            rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
+            imp_packed_rhs_neon.data() + rhs_matmul_offset, imp_dst.data() + dst_offset,
+            dst_stride_row, dst_stride_col, clamp_min, clamp_max);
+
+        const auto success_neon = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
+        ASSERT_TRUE(success_neon);
+    }
+}
+
 INSTANTIATE_TEST_SUITE_P(
     MatMul, MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p,
     testing::Combine(
@@ -373,4 +526,38 @@ INSTANTIATE_TEST_SUITE_P(
         return sstream.str();
     });
 
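+// The BF16 suite below mirrors the F32 suite above: every kernel variant is crossed with each matrix shape,
+// block length (bl), destination portion, and RHS packing order. The name generator concatenates these into the
+// GTest case name; an illustrative (not verbatim) example:
+//
+//   kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod__NxK__<shape>__BL_32__<portion>
+//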
"__NxK" : "__KxN") << "__"; + PrintTo(shape, &sstream); + sstream << "__BL_" << bl << "__"; + PrintTo(portion, &sstream); + + return sstream.str(); + }); + } // namespace kai::test -- GitLab From 65a07a6e3220ce844805abdeb55d2a1d953db17f Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 21 Jul 2025 15:01:26 +0100 Subject: [PATCH 3/5] address comments: small fixes and refactoring Signed-off-by: Evie Wright --- .../matmul_clamp_qai8dxp_qsi4c32p_test.cpp | 49 +++++++------------ 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp index ab029abb..004f74f2 100644 --- a/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp @@ -94,6 +94,19 @@ static const std::array pack_rhs_qsi4c32pscalebf16( size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, const Buffer& rhs_values_qsi4, const Buffer& biases, @@ -292,8 +305,8 @@ TEST_P(MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd) { const auto lhs_stride = K * sizeof(float); - auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); - auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); @@ -495,20 +508,8 @@ TEST_P(MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p, EndToEnd) { INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, testing::Combine( - testing::Range(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.size()), - testing::Values( - MatMulShape{16, 32, 64}, // - MatMulShape{8, 32, 128}, // - MatMulShape{17, 25, 64}, // - MatMulShape{15, 31, 128}, // - MatMulShape{1, 25, 64}), - testing::Values(32, 64), - testing::Values( - MatrixPortion(0, 0, 1, 1), // Full matrix. - MatrixPortion(0, 0, 1, 0.25f), // Leftmost portion. - MatrixPortion(0, 0.75f, 1, 1), // Rightmost portion. - MatrixPortion(0, 0.5f, 1, 0.8f)), // Somewhere Middle - testing::Values(RhsPackType::NxK, RhsPackType::KxN)), + testing::Range(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.size()), test_matmul_shapes, + test_block_lengths, test_portions, testing::Values(RhsPackType::NxK, RhsPackType::KxN)), [](const auto& info) { const auto variant_idx = std::get<0>(info.param); const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_idx).name}; @@ -529,20 +530,8 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p, testing::Combine( - testing::Range(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.size()), - testing::Values( - MatMulShape{16, 32, 64}, // - MatMulShape{8, 32, 128}, // - MatMulShape{17, 25, 64}, // - MatMulShape{15, 31, 128}, // - MatMulShape{1, 25, 64}), - testing::Values(32, 64), - testing::Values( - MatrixPortion(0, 0, 1, 1), // Full matrix. - MatrixPortion(0, 0, 1, 0.25f), // Leftmost portion. - MatrixPortion(0, 0.75f, 1, 1), // Rightmost portion. 
- MatrixPortion(0, 0.5f, 1, 0.8f)), // Somewhere Middle - testing::Values(RhsPackType::NxK, RhsPackType::KxN)), + testing::Range(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.size()), test_matmul_shapes, + test_block_lengths, test_portions, testing::Values(RhsPackType::NxK, RhsPackType::KxN)), [](const auto& info) { const auto variant_idx = std::get<0>(info.param); const std::string name{variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_idx).name}; -- GitLab From 231936034a0c1618afef7198450d0997cb5ee6f2 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 21 Jul 2025 15:53:23 +0100 Subject: [PATCH 4/5] test with larger matrix size, add clear assumptions on parameter values Signed-off-by: Evie Wright --- test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp index 004f74f2..a02cab2e 100644 --- a/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp @@ -95,8 +95,8 @@ static const std::array Date: Mon, 21 Jul 2025 15:57:26 +0100 Subject: [PATCH 5/5] format test shapes one per line Signed-off-by: Evie Wright --- test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp index a02cab2e..7a5ef9e8 100644 --- a/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp @@ -95,8 +95,14 @@ static const std::array