From a797051780b1fdf3d94c0e8d123356f6c56b574f Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Thu, 24 Apr 2025 13:56:49 +0000 Subject: [PATCH 1/9] Matmul Micro-kernels BF16 <- (QAI8DXP) LHS x (QSI4CXP) RHS - Matrix multiplication (MxN) Micro-kernels of QAI8DXP LHS and QSI4CXP RHS with BF16 output, optimized for FEAT_I8MM. - Matrix multiplication (1xN) Micro-kernels of QAI8DXP LHS and QSI4CXP RHS with BF16 output, optimized for FEAT_DotProd. Signed-off-by: Nikhil Gupta --- CHANGELOG.md | 3 + CMakeLists.txt | 54 +- kai/ukernels/matmul/BUILD.bazel | 27 + ...6_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c | 165 ++++ ...6_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h | 139 ++++ ...i8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S | 195 +++++ ...bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c | 165 ++++ ...bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h | 139 ++++ ..._qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S | 733 ++++++++++++++++++ ...mul_clamp_bf16_qai8dxp_qsi4cxp_interface.h | 52 ++ .../kai_lhs_quant_pack_qai8dxp_bf16_neon.c | 254 ++++++ .../kai_lhs_quant_pack_qai8dxp_bf16_neon.h | 76 ++ test/common/cpu_info.cpp | 8 + test/common/cpu_info.hpp | 6 + test/reference/cast.cpp | 1 + test/reference/fill.cpp | 1 + ...matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp | 358 +++++++++ 17 files changed, 2368 insertions(+), 8 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S create mode 100644 kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h create mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h create mode 100644 test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fb3388d..03ebd525 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_I8MM. +- Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_DotProd. 
+ ## v1.11.0 - New Advanced SIMD micro-kernels: diff --git a/CMakeLists.txt b/CMakeLists.txt index e4597460..fdc4183a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -139,7 +139,12 @@ set(KLEIDIAI_FILES_NEON_FP16_I8MM ${KLEIDIAI_FILES_NEON_FP16_I8MM_ASM} ) +set(KLEIDIAI_FILES_NEON_BF16_ASM + kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c +) + set(KLEIDIAI_FILES_NEON_BF16 + ${KLEIDIAI_FILES_NEON_BF16_ASM} kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.c kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c @@ -312,6 +317,24 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) +set(KLEIDIAI_FILES_NEON_BF16_DOTPROD_ASM + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S +) + +set(KLEIDIAI_FILES_NEON_BF16_I8MM_ASM + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S +) + +set(KLEIDIAI_FILES_NEON_BF16_DOTPROD + ${KLEIDIAI_FILES_NEON_BF16_DOTPROD_ASM} +) + +set(KLEIDIAI_FILES_NEON_BF16_I8MM + ${KLEIDIAI_FILES_NEON_BF16_I8MM_ASM} +) + add_library(kleidiai) add_library(${PROJECT_NAME}::kleidiai ALIAS kleidiai) @@ -329,6 +352,9 @@ if(NOT MSVC) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_DOTPROD}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_I8MM}) + set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) @@ -339,6 +365,8 @@ if(NOT MSVC) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16_BF16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) # Use -fno-tree-vectorize option to disable compiler based vectorization set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") @@ -349,20 +377,28 @@ else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) - - set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS 
/arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_DOTPROD}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_I8MM}) + target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_ASM}) + + set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_DOTPROD} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_I8MM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM ${KLEIDIAI_FILES_SME_ASM} ${KLEIDIAI_FILES_SME2_ASM} ${KLEIDIAI_FILES_NEON_ASM} ${KLEIDIAI_FILES_NEON_DOTPROD_ASM} - ${KLEIDIAI_FILES_NEON_I8MM_ASM}) + ${KLEIDIAI_FILES_NEON_I8MM_ASM} + ${KLEIDIAI_FILES_NEON_BF16_DOTPROD_ASM} + ${KLEIDIAI_FILES_NEON_BF16_I8MM_ASM}) list(FILTER KLEIDIAI_FILES_ASM INCLUDE REGEX "^.*\.S$") set_source_files_properties(${KLEIDIAI_FILES_ASM} PROPERTIES LANGUAGE ASM_MARMASM) @@ -448,6 +484,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/float16_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp + test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp ) else() add_executable(kleidiai_test @@ -455,6 +492,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/buffer_test.cpp test/tests/float16_test.cpp test/tests/imatmul_test.cpp + test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index ae9628a3..b961b9e7 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -62,6 +62,7 @@ BF16_KERNELS = [ "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla", "pack/kai_lhs_quant_pack_bf16p1x4_f32_neon", "pack/kai_lhs_quant_pack_bf16p8x4_f32_neon", + 
"pack/kai_lhs_quant_pack_qai8dxp_bf16_neon", "pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon", ] @@ -142,6 +143,16 @@ I8MM_KERNELS_ASM = [ "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", ] +# buildifier: keep sorted +DOTPROD_BF16_KERNELS_ASM = [ + "matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", +] + +# buildifier: keep sorted +I8MM_BF16_KERNELS_ASM = [ + "matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", +] + # buildifier: keep sorted SME_KERNELS = [ "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", @@ -311,17 +322,33 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], ) +kai_c_library( + name = "i8mm_bf16_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in I8MM_BF16_KERNELS_ASM] + [ukernel + ".c" for ukernel in I8MM_BF16_KERNELS_ASM], + cpu_uarch = kai_cpu_i8mm() + kai_cpu_bf16(), + textual_hdrs = [ukernel + ".h" for ukernel in I8MM_BF16_KERNELS_ASM], +) + +kai_c_library( + name = "dotprod_bf16_impl_asm", + srcs = [ukernel + "_asm.S" for ukernel in DOTPROD_BF16_KERNELS_ASM] + [ukernel + ".c" for ukernel in DOTPROD_BF16_KERNELS_ASM], + cpu_uarch = kai_cpu_dotprod() + kai_cpu_bf16(), + textual_hdrs = [ukernel + ".h" for ukernel in DOTPROD_BF16_KERNELS_ASM], +) + kai_c_library( name = "matmul", visibility = ["//visibility:public"], deps = [ ":bf16_impl", + ":dotprod_bf16_impl_asm", ":dotprod_impl", ":dotprod_impl_asm", ":fp16_bf16_impl", ":fp16_dotprod_impl_asm", ":fp16_i8mm_impl_asm", ":fp16_impl", + ":i8mm_bf16_impl_asm", ":i8mm_impl", ":i8mm_impl_asm", ":interface", diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c new file mode 100644 index 00000000..8b356de7 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c @@ -0,0 +1,165 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ + !defined(_M_ARM64) +#error "Dotprod extension and bf16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + uint16_t* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 8; +// Packing args +static const size_t kai_mr = 1; +static const size_t kai_nr = 8; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +static const size_t kai_bl = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float 
scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + if (m == 0) { + return; + } + const size_t k_internal = kai_get_k_roundedup(k); + size_t num_blocks = k_internal / kai_bl; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h new file mode 100644 index 00000000..e0688423 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h @@ -0,0 +1,139 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_bf16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be 1. +/// @param[in] k Total number of columns in the LHS matrix (not packed). 
+/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be 1. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: dotprod +/// +/// @param[in] m The number of output rows written. It must be 1. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. 
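+///
+/// A minimal usage sketch (illustrative only, not part of this interface). It assumes
+/// lhs_packed and rhs_packed were produced by the packing micro-kernels listed at the
+/// top of this file, and that dst points to a BF16 buffer of
+/// kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(1, n) bytes:
+///
+///   kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod(
+///       1, n, k,                            // single LHS row (GEMV-style)
+///       lhs_packed, rhs_packed, dst,
+///       n * sizeof(uint16_t),               // dst_stride_row
+///       sizeof(uint16_t),                   // dst_stride_col (required value)
+///       -FLT_MAX, FLT_MAX);                 // effectively no clamping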
+void kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S new file mode 100644 index 00000000..2f4f66ef --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S @@ -0,0 +1,195 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod) + stp x20, x21, [sp, -80]! 
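+    // Save the remaining callee-saved registers (x22..x28) used by this kernel.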
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x13, #0x20 + movi v5.16b, #0xf0 + mov x21, #0x8 + ldr x12, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x11, [x0, #0x8] + ldr x10, [x0, #0x10] + ldr x9, [x0, #0x30] + ldr x28, [x0, #0x0] + ldr x27, [x0, #0x20] + madd x13, x12, x13, x21 + ldr x26, [x0, #0x18] + mov x25, x20 +KAI_ASM_LABEL(label_1) // Row loop + mov x24, x10 + mov x23, x9 + add x22, x28, x27 +KAI_ASM_LABEL(label_2) // Column loop + mov x21, x11 + movi v4.4s, #0x0 + movi v3.4s, #0x0 + mov x20, x12 + movi v2.4s, #0x0 + movi v1.4s, #0x0 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q0, [x24, #0x0] + ldr q31, [x24, #0x10] + subs x20, x20, #0x1 + ldr q30, [x24, #0x20] + ldr q29, [x24, #0x30] + ld1r { v28.2d }, [x21], #0x8 + ldr q27, [x24, #0x40] + ldr q26, [x24, #0x50] + ldr q25, [x24, #0x60] + shl v24.16b, v0.16b, #0x4 + shl v18.16b, v31.16b, #0x4 + ldr q23, [x24, #0x70] + shl v17.16b, v30.16b, #0x4 + shl v16.16b, v29.16b, #0x4 + add x24, x24, #0x80 + ld1r { v22.2d }, [x21], #0x8 + shl v21.16b, v27.16b, #0x4 + and v0.16b, v0.16b, v5.16b + ld1r { v20.2d }, [x21], #0x8 + ld1r { v19.2d }, [x21], #0x8 + KAI_ASM_INST(0x4e9c9704) // sdot v4.4s, v24.16b, v28.16b + KAI_ASM_INST(0x4e9c9643) // sdot v3.4s, v18.16b, v28.16b + shl v18.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4e9c9622) // sdot v2.4s, v17.16b, v28.16b + KAI_ASM_INST(0x4e9c9601) // sdot v1.4s, v16.16b, v28.16b + shl v17.16b, v25.16b, #0x4 + shl v16.16b, v23.16b, #0x4 + and v31.16b, v31.16b, v5.16b + and v30.16b, v30.16b, v5.16b + and v29.16b, v29.16b, v5.16b + KAI_ASM_INST(0x4e9696a4) // sdot v4.4s, v21.16b, v22.16b + KAI_ASM_INST(0x4e969643) // sdot v3.4s, v18.16b, v22.16b + and v27.16b, v27.16b, v5.16b + KAI_ASM_INST(0x4e969622) // sdot v2.4s, v17.16b, v22.16b + KAI_ASM_INST(0x4e969601) // sdot v1.4s, v16.16b, v22.16b + and v26.16b, v26.16b, v5.16b + and v25.16b, v25.16b, v5.16b + and v23.16b, v23.16b, v5.16b + KAI_ASM_INST(0x4e949404) // sdot v4.4s, v0.16b, v20.16b + KAI_ASM_INST(0x4e9497e3) // sdot v3.4s, v31.16b, v20.16b + KAI_ASM_INST(0x4e9497c2) // sdot v2.4s, v30.16b, v20.16b + KAI_ASM_INST(0x4e9497a1) // sdot v1.4s, v29.16b, v20.16b + KAI_ASM_INST(0x4e939764) // sdot v4.4s, v27.16b, v19.16b + KAI_ASM_INST(0x4e939743) // sdot v3.4s, v26.16b, v19.16b + KAI_ASM_INST(0x4e939722) // sdot v2.4s, v25.16b, v19.16b + KAI_ASM_INST(0x4e9396e1) // sdot v1.4s, v23.16b, v19.16b + bgt label_3 + ldr q18, [x24, #0x0] + ld1r { v24.4s }, [x21] + addp v4.4s, v4.4s, v3.4s + addp v2.4s, v2.4s, v1.4s + ldr q23, [x24, #0x10] + ldr q22, [x24, #0x20] + add x21, x21, #0x4 + add x20, x26, #0x4 + ld1r { v16.4s }, [x21] + ldr q17, [x24, #0x30] + cmp x23, #0x8 + ldr q21, [x24, #0x40] + ldr q20, [x24, #0x50] + mla v4.4s, v18.4s, v24.s[0] + add x24, x24, #0x60 + ld1r { v19.4s }, [x26] + ld1r { v18.4s }, [x20] + mla v2.4s, v23.4s, v24.s[0] + fmul v22.4s, v22.4s, v16.4s + fmul v17.4s, v17.4s, v16.4s + scvtf v4.4s, v4.4s + fmul v16.4s, v4.4s, v22.4s + scvtf v2.4s, v2.4s + fmul v17.4s, v2.4s, v17.4s + fadd v16.4s, v16.4s, v21.4s + fadd v17.4s, v17.4s, v20.4s + fmax v16.4s, v16.4s, v19.4s + fmin v16.4s, v16.4s, v18.4s + fmax v17.4s, v17.4s, v19.4s + fmin v17.4s, v17.4s, v18.4s + KAI_ASM_INST(0x0ea16a10) // bfcvtn v16.4h, v16.4s + KAI_ASM_INST(0x4ea16a30) // bfcvtn2 v16.8h, v17.4s + blt label_4 + str q16, [x28, #0x0] + b label_9 +KAI_ASM_LABEL(label_4) // Partial output + mov x20, x28 + tbz x23, #2, label_6 + st1 { v16.d }[0], [x20], #0x8 + tbz x23, #1, label_5 + st1 { v16.s }[2], [x20], #0x4 + tbz x23, #0, 
label_8 + st1 { v16.h }[6], [x20] + b label_8 +KAI_ASM_LABEL(label_5) // Output block 0: partial_1_4 + tbz x23, #0, label_8 + st1 { v16.h }[4], [x20] + b label_8 +KAI_ASM_LABEL(label_6) // Output block 0: partial_2_0 + tbz x23, #1, label_7 + st1 { v16.s }[0], [x20], #0x4 + tbz x23, #0, label_8 + st1 { v16.h }[2], [x20] + b label_8 +KAI_ASM_LABEL(label_7) // Output block 0: partial_1_0 + st1 { v16.h }[0], [x20] +KAI_ASM_LABEL(label_8) // Output block 0: Done +KAI_ASM_LABEL(label_9) // Stores done + subs x23, x23, #0x8 + add x28, x28, #0x10 + bgt label_2 + subs x25, x25, #0x1 + add x11, x11, x13 + mov x28, x22 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c new file mode 100644 index 00000000..d4ac78fb --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c @@ -0,0 +1,165 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ + !defined(_M_ARM64) +#error "I8mm extension and bf16 vector arithmetic required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + uint16_t* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 8; +// Packing args +static const size_t kai_mr = 4; +static const size_t kai_nr = 8; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +static const size_t kai_num_bytes_zp_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_rsum_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 2; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +static const size_t kai_bl = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + // Since the LHS matrix is asymmetric with per-row quantization, we must include the + // the number of bytes to hold the zero point value + lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs; + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include + // the number of bytes for the reduction sum + rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + void* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { 
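+    // The destination is BF16 (2 bytes per element), so the only supported column
+    // stride is sizeof(uint16_t). K is rounded up to a multiple of 32 to match the
+    // packed layouts, and the assembly kernel consumes it in blocks of kai_bl values.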
+ KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + if (m == 0) { + return; + } + const size_t k_internal = kai_get_k_roundedup(k); + size_t num_blocks = k_internal / kai_bl; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h new file mode 100644 index 00000000..1c9fb83d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h @@ -0,0 +1,139 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qai8dxp_bf16_neon to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void); + +/// Gets the nr value, which must be used to pack the RHS matrix. +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step. +/// @param[in] k Total number of columns in the LHS matrix (not packed). 
+/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) values. +/// +/// @param[in] n_idx Col index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm( + size_t n_idx, // + size_t k); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of n_step. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Quantized Asymmetric Signed 8-bit with per-row quantization (qai8dx) and packed. +/// RHS matrix: Quantized Symmetric Signed 4-bit with per-channel quantization (qsi4cx) and packed. +/// Output tile: (rows x cols) = m_step x n_step. +/// +/// Note: Please refer to the get functions for m_step and n_step for the exact values. +/// +/// Features used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(uint16_t) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. 
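+///
+/// A minimal usage sketch (illustrative only, not part of this interface) for splitting
+/// the output into m_step x n_step tiles, e.g. across threads. It assumes lhs_packed and
+/// rhs_packed were produced by the packing micro-kernels listed at the top of this file,
+/// and that (m_idx, n_idx) is the top-left corner of a tile of m_block x n_block elements,
+/// with m_idx a multiple of m_step and n_idx a multiple of n_step:
+///
+///   const size_t dst_stride_row = n * sizeof(uint16_t);
+///   const size_t lhs_offset =
+///       kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(m_idx, k);
+///   const size_t rhs_offset =
+///       kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(n_idx, k);
+///   const size_t dst_offset =
+///       kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(m_idx, n_idx, dst_stride_row);
+///
+///   kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm(
+///       m_block, n_block, k,                             // rows/cols handled by this tile
+///       (const uint8_t*)lhs_packed + lhs_offset,
+///       (const uint8_t*)rhs_packed + rhs_offset,
+///       (uint8_t*)dst + dst_offset,
+///       dst_stride_row, sizeof(uint16_t),                // dst_stride_col (required value)
+///       -FLT_MAX, FLT_MAX);                              // effectively no clamping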
+void kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + void* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S new file mode 100644 index 00000000..c78a94da --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S @@ -0,0 +1,733 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm) + stp x20, x21, [sp, -144]! 
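+    // Save the remaining callee-saved registers (x22..x28 and d8..d15) used by this kernel.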
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d12, d13, [sp, 88] + stp d14, d15, [sp, 104] + stp d8, d9, [sp, 120] + mov x7, #0x80 + movi v1.16b, #0xf0 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x8, [x0, #0x38] + ldr x17, [x0, #0x8] + ldr x16, [x0, #0x10] + ldr x15, [x0, #0x30] + ldr x14, [x0, #0x0] + mov x13, x20 + ldr x12, [x0, #0x20] + madd x7, x8, x7, x21 + ldr x11, [x0, #0x18] + cmp x13, #0x8 + blt label_12 +KAI_ASM_LABEL(label_1) // Row loop + mov x10, x16 + mov x9, x15 + add x28, x14, x12, LSL #3 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x17 + movi v12.4s, #0x0 + movi v10.4s, #0x0 + mov x21, x8 + movi v13.4s, #0x0 + movi v0.4s, #0x0 + movi v15.4s, #0x0 + movi v31.4s, #0x0 + add x20, x22, x7 + movi v14.4s, #0x0 + movi v8.4s, #0x0 + movi v9.4s, #0x0 + movi v7.4s, #0x0 + movi v16.4s, #0x0 + movi v2.4s, #0x0 + movi v20.4s, #0x0 + movi v30.4s, #0x0 + movi v5.4s, #0x0 + movi v17.4s, #0x0 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q21, [x10, #0x0] + ldr q3, [x10, #0x10] + subs x21, x21, #0x1 + ldr q26, [x10, #0x20] + ldr q4, [x10, #0x30] + ldr q6, [x22, #0x0] + ldr q29, [x22, #0x10] + ldr q24, [x20, #0x0] + ldr q27, [x20, #0x10] + shl v23.16b, v21.16b, #0x4 + shl v19.16b, v3.16b, #0x4 + ldr q18, [x10, #0x40] + ldr q28, [x10, #0x50] + shl v22.16b, v26.16b, #0x4 + shl v11.16b, v4.16b, #0x4 + ldr q25, [x10, #0x60] + and v21.16b, v21.16b, v1.16b + and v3.16b, v3.16b, v1.16b + KAI_ASM_INST(0x4e97a4cc) // smmla v12.4s, v6.16b, v23.16b + KAI_ASM_INST(0x4e93a4cd) // smmla v13.4s, v6.16b, v19.16b + KAI_ASM_INST(0x4e97a7af) // smmla v15.4s, v29.16b, v23.16b + and v26.16b, v26.16b, v1.16b + KAI_ASM_INST(0x4e96a4ca) // smmla v10.4s, v6.16b, v22.16b + KAI_ASM_INST(0x4e8ba4c0) // smmla v0.4s, v6.16b, v11.16b + ldr q6, [x10, #0x70] + and v4.16b, v4.16b, v1.16b + KAI_ASM_INST(0x4e93a7ae) // smmla v14.4s, v29.16b, v19.16b + KAI_ASM_INST(0x4e96a7bf) // smmla v31.4s, v29.16b, v22.16b + add x10, x10, #0x80 + KAI_ASM_INST(0x4e8ba7a8) // smmla v8.4s, v29.16b, v11.16b + ldr q29, [x22, #0x20] + KAI_ASM_INST(0x4e97a709) // smmla v9.4s, v24.16b, v23.16b + KAI_ASM_INST(0x4e93a710) // smmla v16.4s, v24.16b, v19.16b + KAI_ASM_INST(0x4e96a707) // smmla v7.4s, v24.16b, v22.16b + KAI_ASM_INST(0x4e8ba702) // smmla v2.4s, v24.16b, v11.16b + ldr q24, [x22, #0x30] + KAI_ASM_INST(0x4e97a774) // smmla v20.4s, v27.16b, v23.16b + ldr q23, [x20, #0x20] + KAI_ASM_INST(0x4e93a765) // smmla v5.4s, v27.16b, v19.16b + ldr q19, [x20, #0x30] + KAI_ASM_INST(0x4e96a77e) // smmla v30.4s, v27.16b, v22.16b + ldr q22, [x22, #0x40] + KAI_ASM_INST(0x4e8ba771) // smmla v17.4s, v27.16b, v11.16b + ldr q11, [x22, #0x50] + shl v27.16b, v18.16b, #0x4 + and v18.16b, v18.16b, v1.16b + KAI_ASM_INST(0x4e9ba7ac) // smmla v12.4s, v29.16b, v27.16b + KAI_ASM_INST(0x4e9ba70f) // smmla v15.4s, v24.16b, v27.16b + KAI_ASM_INST(0x4e9ba6e9) // smmla v9.4s, v23.16b, v27.16b + KAI_ASM_INST(0x4e9ba674) // smmla v20.4s, v19.16b, v27.16b + shl v27.16b, v28.16b, #0x4 + and v28.16b, v28.16b, v1.16b + KAI_ASM_INST(0x4e9ba7ad) // smmla v13.4s, v29.16b, v27.16b + KAI_ASM_INST(0x4e9ba70e) // smmla v14.4s, v24.16b, v27.16b + KAI_ASM_INST(0x4e9ba6f0) // smmla v16.4s, v23.16b, v27.16b + KAI_ASM_INST(0x4e9ba665) // smmla v5.4s, v19.16b, v27.16b + shl v27.16b, v25.16b, #0x4 + KAI_ASM_INST(0x4e95a6cc) // smmla v12.4s, v22.16b, v21.16b + KAI_ASM_INST(0x4e95a56f) // smmla v15.4s, v11.16b, v21.16b + and v25.16b, v25.16b, v1.16b + KAI_ASM_INST(0x4e9ba7aa) // smmla v10.4s, v29.16b, v27.16b + 
KAI_ASM_INST(0x4e9ba71f) // smmla v31.4s, v24.16b, v27.16b + KAI_ASM_INST(0x4e9ba6e7) // smmla v7.4s, v23.16b, v27.16b + KAI_ASM_INST(0x4e9ba67e) // smmla v30.4s, v19.16b, v27.16b + shl v27.16b, v6.16b, #0x4 + KAI_ASM_INST(0x4e83a6cd) // smmla v13.4s, v22.16b, v3.16b + KAI_ASM_INST(0x4e83a56e) // smmla v14.4s, v11.16b, v3.16b + and v6.16b, v6.16b, v1.16b + KAI_ASM_INST(0x4e9ba7a0) // smmla v0.4s, v29.16b, v27.16b + ldr q29, [x20, #0x40] + KAI_ASM_INST(0x4e9ba708) // smmla v8.4s, v24.16b, v27.16b + ldr q24, [x20, #0x50] + KAI_ASM_INST(0x4e9ba6e2) // smmla v2.4s, v23.16b, v27.16b + ldr q23, [x22, #0x60] + KAI_ASM_INST(0x4e9ba671) // smmla v17.4s, v19.16b, v27.16b + ldr q19, [x22, #0x70] + ldr q27, [x20, #0x60] + KAI_ASM_INST(0x4e9aa6ca) // smmla v10.4s, v22.16b, v26.16b + KAI_ASM_INST(0x4e9aa57f) // smmla v31.4s, v11.16b, v26.16b + add x22, x22, #0x80 + KAI_ASM_INST(0x4e95a7a9) // smmla v9.4s, v29.16b, v21.16b + KAI_ASM_INST(0x4e83a7b0) // smmla v16.4s, v29.16b, v3.16b + KAI_ASM_INST(0x4e84a6c0) // smmla v0.4s, v22.16b, v4.16b + ldr q22, [x20, #0x70] + KAI_ASM_INST(0x4e84a568) // smmla v8.4s, v11.16b, v4.16b + add x20, x20, #0x80 + KAI_ASM_INST(0x4e9aa7a7) // smmla v7.4s, v29.16b, v26.16b + KAI_ASM_INST(0x4e84a7a2) // smmla v2.4s, v29.16b, v4.16b + KAI_ASM_INST(0x4e95a714) // smmla v20.4s, v24.16b, v21.16b + KAI_ASM_INST(0x4e83a705) // smmla v5.4s, v24.16b, v3.16b + KAI_ASM_INST(0x4e9aa71e) // smmla v30.4s, v24.16b, v26.16b + KAI_ASM_INST(0x4e84a711) // smmla v17.4s, v24.16b, v4.16b + KAI_ASM_INST(0x4e92a6ec) // smmla v12.4s, v23.16b, v18.16b + KAI_ASM_INST(0x4e9ca6ed) // smmla v13.4s, v23.16b, v28.16b + KAI_ASM_INST(0x4e99a6ea) // smmla v10.4s, v23.16b, v25.16b + KAI_ASM_INST(0x4e86a6e0) // smmla v0.4s, v23.16b, v6.16b + KAI_ASM_INST(0x4e92a66f) // smmla v15.4s, v19.16b, v18.16b + KAI_ASM_INST(0x4e9ca66e) // smmla v14.4s, v19.16b, v28.16b + KAI_ASM_INST(0x4e99a67f) // smmla v31.4s, v19.16b, v25.16b + KAI_ASM_INST(0x4e86a668) // smmla v8.4s, v19.16b, v6.16b + KAI_ASM_INST(0x4e92a769) // smmla v9.4s, v27.16b, v18.16b + KAI_ASM_INST(0x4e9ca770) // smmla v16.4s, v27.16b, v28.16b + KAI_ASM_INST(0x4e99a767) // smmla v7.4s, v27.16b, v25.16b + KAI_ASM_INST(0x4e86a762) // smmla v2.4s, v27.16b, v6.16b + KAI_ASM_INST(0x4e92a6d4) // smmla v20.4s, v22.16b, v18.16b + KAI_ASM_INST(0x4e9ca6c5) // smmla v5.4s, v22.16b, v28.16b + KAI_ASM_INST(0x4e99a6de) // smmla v30.4s, v22.16b, v25.16b + KAI_ASM_INST(0x4e86a6d1) // smmla v17.4s, v22.16b, v6.16b + bgt label_3 + ldr q22, [x10, #0x0] + ldr q4, [x10, #0x10] + uzp1 v3.2d, v12.2d, v13.2d + uzp2 v27.2d, v12.2d, v13.2d + ld1 { v18.4s }, [x22] + ldr q6, [x10, #0x20] + uzp1 v26.2d, v10.2d, v0.2d + uzp2 v24.2d, v10.2d, v0.2d + ldr q11, [x10, #0x30] + uzp1 v25.2d, v15.2d, v14.2d + uzp2 v29.2d, v15.2d, v14.2d + add x22, x22, #0x10 + ldr q13, [x22, #0x0] + uzp1 v28.2d, v31.2d, v8.2d + uzp2 v23.2d, v31.2d, v8.2d + add x10, x10, #0x40 + mla v3.4s, v22.4s, v18.s[0] + mla v26.4s, v4.4s, v18.s[0] + mla v27.4s, v22.4s, v18.s[1] + mla v24.4s, v4.4s, v18.s[1] + mla v25.4s, v22.4s, v18.s[2] + mla v28.4s, v4.4s, v18.s[2] + fmul v10.4s, v6.4s, v13.s[0] + mla v29.4s, v22.4s, v18.s[3] + mla v23.4s, v4.4s, v18.s[3] + fmul v0.4s, v11.4s, v13.s[0] + scvtf v3.4s, v3.4s + scvtf v26.4s, v26.4s + fmul v15.4s, v6.4s, v13.s[1] + scvtf v27.4s, v27.4s + fmul v21.4s, v11.4s, v13.s[1] + scvtf v24.4s, v24.4s + fmul v14.4s, v6.4s, v13.s[2] + scvtf v25.4s, v25.4s + fmul v31.4s, v11.4s, v13.s[2] + scvtf v28.4s, v28.4s + fmul v18.4s, v6.4s, v13.s[3] + scvtf v29.4s, v29.4s + fmul v19.4s, v11.4s, 
v13.s[3] + scvtf v23.4s, v23.4s + fmul v12.4s, v3.4s, v10.4s + fmul v10.4s, v26.4s, v0.4s + fmul v13.4s, v27.4s, v15.4s + fmul v0.4s, v24.4s, v21.4s + fmul v15.4s, v25.4s, v14.4s + fmul v31.4s, v28.4s, v31.4s + fmul v14.4s, v29.4s, v18.4s + fmul v8.4s, v23.4s, v19.4s + ld1 { v21.4s }, [x20] + uzp1 v18.2d, v9.2d, v16.2d + uzp2 v3.2d, v9.2d, v16.2d + add x20, x20, #0x10 + ldr q16, [x20, #0x0] + uzp1 v26.2d, v7.2d, v2.2d + uzp2 v25.2d, v7.2d, v2.2d + uzp1 v24.2d, v20.2d, v5.2d + uzp2 v29.2d, v20.2d, v5.2d + uzp1 v28.2d, v30.2d, v17.2d + uzp2 v27.2d, v30.2d, v17.2d + mla v18.4s, v22.4s, v21.s[0] + mla v26.4s, v4.4s, v21.s[0] + mla v3.4s, v22.4s, v21.s[1] + fmul v23.4s, v6.4s, v16.s[0] + mla v25.4s, v4.4s, v21.s[1] + mla v24.4s, v22.4s, v21.s[2] + fmul v7.4s, v11.4s, v16.s[0] + mla v28.4s, v4.4s, v21.s[2] + mla v29.4s, v22.4s, v21.s[3] + fmul v22.4s, v6.4s, v16.s[1] + mla v27.4s, v4.4s, v21.s[3] + scvtf v18.4s, v18.4s + scvtf v26.4s, v26.4s + scvtf v3.4s, v3.4s + fmul v2.4s, v11.4s, v16.s[1] + scvtf v25.4s, v25.4s + fmul v19.4s, v6.4s, v16.s[2] + scvtf v24.4s, v24.4s + fmul v30.4s, v11.4s, v16.s[2] + scvtf v28.4s, v28.4s + fmul v5.4s, v6.4s, v16.s[3] + scvtf v29.4s, v29.4s + fmul v6.4s, v11.4s, v16.s[3] + scvtf v27.4s, v27.4s + fmul v9.4s, v18.4s, v23.4s + fmul v7.4s, v26.4s, v7.4s + fmul v16.4s, v3.4s, v22.4s + fmul v2.4s, v25.4s, v2.4s + fmul v20.4s, v24.4s, v19.4s + fmul v30.4s, v28.4s, v30.4s + fmul v5.4s, v29.4s, v5.4s + fmul v17.4s, v27.4s, v6.4s + ldr q25, [x10, #0x0] + ldr q18, [x10, #0x10] + add x20, x11, #0x4 + cmp x9, #0x8 + ld1r { v24.4s }, [x11] + ld1r { v27.4s }, [x20] + add x10, x10, #0x20 + fadd v12.4s, v12.4s, v25.4s + fadd v13.4s, v13.4s, v25.4s + fadd v15.4s, v15.4s, v25.4s + fadd v14.4s, v14.4s, v25.4s + fadd v9.4s, v9.4s, v25.4s + fadd v16.4s, v16.4s, v25.4s + fadd v20.4s, v20.4s, v25.4s + fadd v5.4s, v5.4s, v25.4s + fadd v10.4s, v10.4s, v18.4s + fadd v0.4s, v0.4s, v18.4s + fadd v31.4s, v31.4s, v18.4s + fadd v8.4s, v8.4s, v18.4s + fadd v7.4s, v7.4s, v18.4s + fadd v2.4s, v2.4s, v18.4s + fadd v30.4s, v30.4s, v18.4s + fadd v17.4s, v17.4s, v18.4s + fmax v12.4s, v12.4s, v24.4s + fmax v13.4s, v13.4s, v24.4s + fmax v15.4s, v15.4s, v24.4s + fmax v14.4s, v14.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v16.4s, v16.4s, v24.4s + fmax v20.4s, v20.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmin v12.4s, v12.4s, v27.4s + fmax v10.4s, v10.4s, v24.4s + fmin v13.4s, v13.4s, v27.4s + fmax v0.4s, v0.4s, v24.4s + fmin v15.4s, v15.4s, v27.4s + fmax v31.4s, v31.4s, v24.4s + fmin v14.4s, v14.4s, v27.4s + fmax v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v27.4s + fmax v7.4s, v7.4s, v24.4s + fmin v16.4s, v16.4s, v27.4s + fmax v2.4s, v2.4s, v24.4s + fmin v20.4s, v20.4s, v27.4s + fmax v30.4s, v30.4s, v24.4s + fmin v5.4s, v5.4s, v27.4s + fmax v17.4s, v17.4s, v24.4s + fmin v10.4s, v10.4s, v27.4s + fmin v0.4s, v0.4s, v27.4s + fmin v31.4s, v31.4s, v27.4s + fmin v8.4s, v8.4s, v27.4s + fmin v7.4s, v7.4s, v27.4s + fmin v2.4s, v2.4s, v27.4s + fmin v30.4s, v30.4s, v27.4s + fmin v17.4s, v17.4s, v27.4s + KAI_ASM_INST(0x0ea16997) // bfcvtn v23.4h, v12.4s + KAI_ASM_INST(0x0ea169b6) // bfcvtn v22.4h, v13.4s + KAI_ASM_INST(0x0ea169e6) // bfcvtn v6.4h, v15.4s + KAI_ASM_INST(0x0ea169cc) // bfcvtn v12.4h, v14.4s + KAI_ASM_INST(0x0ea16939) // bfcvtn v25.4h, v9.4s + KAI_ASM_INST(0x0ea16a12) // bfcvtn v18.4h, v16.4s + KAI_ASM_INST(0x0ea16a8e) // bfcvtn v14.4h, v20.4s + KAI_ASM_INST(0x0ea168a9) // bfcvtn v9.4h, v5.4s + KAI_ASM_INST(0x4ea16957) // bfcvtn2 v23.8h, v10.4s + KAI_ASM_INST(0x4ea16816) // bfcvtn2 v22.8h, v0.4s + 
KAI_ASM_INST(0x4ea16be6) // bfcvtn2 v6.8h, v31.4s + KAI_ASM_INST(0x4ea1690c) // bfcvtn2 v12.8h, v8.4s + KAI_ASM_INST(0x4ea168f9) // bfcvtn2 v25.8h, v7.4s + KAI_ASM_INST(0x4ea16852) // bfcvtn2 v18.8h, v2.4s + KAI_ASM_INST(0x4ea16bce) // bfcvtn2 v14.8h, v30.4s + KAI_ASM_INST(0x4ea16a29) // bfcvtn2 v9.8h, v17.4s + blt label_6 + mov x20, x14 + str q23, [x20, #0x0] + add x20, x20, x12 + str q22, [x20, #0x0] + add x20, x20, x12 + str q6, [x20, #0x0] + add x20, x20, x12 + str q12, [x20, #0x0] + add x20, x20, x12 + str q25, [x20, #0x0] + add x20, x20, x12 + str q18, [x20, #0x0] + add x20, x20, x12 + str q14, [x20, #0x0] + add x20, x20, x12 + str q9, [x20, #0x0] + b label_11 +KAI_ASM_LABEL(label_6) // Partial output + mov x27, x14 + add x26, x27, x12, LSL #2 + add x25, x26, x12, LSL #1 + add x24, x26, x12 + add x23, x25, x12 + add x22, x27, x12, LSL #1 + add x21, x27, x12 + add x20, x22, x12 + tbz x9, #2, label_8 + st1 { v9.d }[0], [x23], #0x8 + st1 { v14.d }[0], [x25], #0x8 + st1 { v18.d }[0], [x24], #0x8 + st1 { v25.d }[0], [x26], #0x8 + st1 { v12.d }[0], [x20], #0x8 + st1 { v6.d }[0], [x22], #0x8 + st1 { v22.d }[0], [x21], #0x8 + st1 { v23.d }[0], [x27], #0x8 + tbz x9, #1, label_7 + st1 { v9.s }[2], [x23], #0x4 + st1 { v14.s }[2], [x25], #0x4 + st1 { v18.s }[2], [x24], #0x4 + st1 { v25.s }[2], [x26], #0x4 + st1 { v12.s }[2], [x20], #0x4 + st1 { v6.s }[2], [x22], #0x4 + st1 { v22.s }[2], [x21], #0x4 + st1 { v23.s }[2], [x27], #0x4 + tbz x9, #0, label_10 + st1 { v9.h }[6], [x23] + st1 { v14.h }[6], [x25] + st1 { v18.h }[6], [x24] + st1 { v25.h }[6], [x26] + st1 { v12.h }[6], [x20] + st1 { v6.h }[6], [x22] + st1 { v22.h }[6], [x21] + st1 { v23.h }[6], [x27] + b label_10 +KAI_ASM_LABEL(label_7) // Output block 0: partial_1_4 + tbz x9, #0, label_10 + st1 { v9.h }[4], [x23] + st1 { v14.h }[4], [x25] + st1 { v18.h }[4], [x24] + st1 { v25.h }[4], [x26] + st1 { v12.h }[4], [x20] + st1 { v6.h }[4], [x22] + st1 { v22.h }[4], [x21] + st1 { v23.h }[4], [x27] + b label_10 +KAI_ASM_LABEL(label_8) // Output block 0: partial_2_0 + tbz x9, #1, label_9 + st1 { v9.s }[0], [x23], #0x4 + st1 { v14.s }[0], [x25], #0x4 + st1 { v18.s }[0], [x24], #0x4 + st1 { v25.s }[0], [x26], #0x4 + st1 { v12.s }[0], [x20], #0x4 + st1 { v6.s }[0], [x22], #0x4 + st1 { v22.s }[0], [x21], #0x4 + st1 { v23.s }[0], [x27], #0x4 + tbz x9, #0, label_10 + st1 { v9.h }[2], [x23] + st1 { v14.h }[2], [x25] + st1 { v18.h }[2], [x24] + st1 { v25.h }[2], [x26] + st1 { v12.h }[2], [x20] + st1 { v6.h }[2], [x22] + st1 { v22.h }[2], [x21] + st1 { v23.h }[2], [x27] + b label_10 +KAI_ASM_LABEL(label_9) // Output block 0: partial_1_0 + st1 { v9.h }[0], [x23] + st1 { v14.h }[0], [x25] + st1 { v18.h }[0], [x24] + st1 { v25.h }[0], [x26] + st1 { v12.h }[0], [x20] + st1 { v6.h }[0], [x22] + st1 { v22.h }[0], [x21] + st1 { v23.h }[0], [x27] +KAI_ASM_LABEL(label_10) // Output block 0: Done +KAI_ASM_LABEL(label_11) // Output stage exit + subs x9, x9, #0x8 + add x14, x14, #0x10 + bgt label_2 + mov x20, #0x2 + sub x13, x13, #0x8 + cmp x13, #0x8 + mov x14, x28 + madd x17, x20, x7, x17 + bge label_1 +KAI_ASM_LABEL(label_12) // Row loop skip + cbz x13, label_23 +KAI_ASM_LABEL(label_13) // Row tail: Row loop + mov x26, x16 + mov x25, x15 + add x24, x14, x12, LSL #2 +KAI_ASM_LABEL(label_14) // Row tail: Column loop + mov x22, x17 + movi v12.4s, #0x0 + movi v10.4s, #0x0 + mov x20, x8 + movi v13.4s, #0x0 + movi v0.4s, #0x0 + movi v15.4s, #0x0 + movi v31.4s, #0x0 + movi v14.4s, #0x0 + movi v8.4s, #0x0 +KAI_ASM_LABEL(label_15) // Row tail: Sub block loop + ldr q7, [x26, 
#0x0] + ldr q6, [x26, #0x10] + subs x20, x20, #0x1 + ldr q5, [x26, #0x20] + ldr q4, [x26, #0x30] + ldr q3, [x22, #0x0] + ldr q2, [x22, #0x10] + ldr q21, [x26, #0x40] + ldr q9, [x26, #0x50] + shl v19.16b, v7.16b, #0x4 + shl v18.16b, v6.16b, #0x4 + ldr q28, [x26, #0x60] + ldr q26, [x26, #0x70] + shl v29.16b, v5.16b, #0x4 + shl v30.16b, v4.16b, #0x4 + ldr q25, [x22, #0x20] + ldr q24, [x22, #0x30] + and v7.16b, v7.16b, v1.16b + and v6.16b, v6.16b, v1.16b + ldr q23, [x22, #0x40] + ldr q22, [x22, #0x50] + KAI_ASM_INST(0x4e93a46c) // smmla v12.4s, v3.16b, v19.16b + KAI_ASM_INST(0x4e92a46d) // smmla v13.4s, v3.16b, v18.16b + ldr q11, [x22, #0x60] + ldr q20, [x22, #0x70] + KAI_ASM_INST(0x4e9da46a) // smmla v10.4s, v3.16b, v29.16b + KAI_ASM_INST(0x4e9ea460) // smmla v0.4s, v3.16b, v30.16b + KAI_ASM_INST(0x4e93a44f) // smmla v15.4s, v2.16b, v19.16b + KAI_ASM_INST(0x4e92a44e) // smmla v14.4s, v2.16b, v18.16b + shl v19.16b, v21.16b, #0x4 + add x26, x26, #0x80 + KAI_ASM_INST(0x4e9da45f) // smmla v31.4s, v2.16b, v29.16b + KAI_ASM_INST(0x4e9ea448) // smmla v8.4s, v2.16b, v30.16b + shl v18.16b, v9.16b, #0x4 + add x22, x22, #0x80 + shl v27.16b, v28.16b, #0x4 + shl v29.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4e93a72c) // smmla v12.4s, v25.16b, v19.16b + and v5.16b, v5.16b, v1.16b + and v4.16b, v4.16b, v1.16b + KAI_ASM_INST(0x4e92a72d) // smmla v13.4s, v25.16b, v18.16b + KAI_ASM_INST(0x4e93a70f) // smmla v15.4s, v24.16b, v19.16b + KAI_ASM_INST(0x4e92a70e) // smmla v14.4s, v24.16b, v18.16b + and v21.16b, v21.16b, v1.16b + KAI_ASM_INST(0x4e9ba72a) // smmla v10.4s, v25.16b, v27.16b + KAI_ASM_INST(0x4e9da720) // smmla v0.4s, v25.16b, v29.16b + and v9.16b, v9.16b, v1.16b + KAI_ASM_INST(0x4e9ba71f) // smmla v31.4s, v24.16b, v27.16b + KAI_ASM_INST(0x4e9da708) // smmla v8.4s, v24.16b, v29.16b + and v28.16b, v28.16b, v1.16b + KAI_ASM_INST(0x4e87a6ec) // smmla v12.4s, v23.16b, v7.16b + KAI_ASM_INST(0x4e86a6ed) // smmla v13.4s, v23.16b, v6.16b + and v26.16b, v26.16b, v1.16b + KAI_ASM_INST(0x4e87a6cf) // smmla v15.4s, v22.16b, v7.16b + KAI_ASM_INST(0x4e86a6ce) // smmla v14.4s, v22.16b, v6.16b + KAI_ASM_INST(0x4e85a6ea) // smmla v10.4s, v23.16b, v5.16b + KAI_ASM_INST(0x4e84a6e0) // smmla v0.4s, v23.16b, v4.16b + KAI_ASM_INST(0x4e85a6df) // smmla v31.4s, v22.16b, v5.16b + KAI_ASM_INST(0x4e84a6c8) // smmla v8.4s, v22.16b, v4.16b + KAI_ASM_INST(0x4e95a56c) // smmla v12.4s, v11.16b, v21.16b + KAI_ASM_INST(0x4e89a56d) // smmla v13.4s, v11.16b, v9.16b + KAI_ASM_INST(0x4e95a68f) // smmla v15.4s, v20.16b, v21.16b + KAI_ASM_INST(0x4e89a68e) // smmla v14.4s, v20.16b, v9.16b + KAI_ASM_INST(0x4e9ca56a) // smmla v10.4s, v11.16b, v28.16b + KAI_ASM_INST(0x4e9aa560) // smmla v0.4s, v11.16b, v26.16b + KAI_ASM_INST(0x4e9ca69f) // smmla v31.4s, v20.16b, v28.16b + KAI_ASM_INST(0x4e9aa688) // smmla v8.4s, v20.16b, v26.16b + bgt label_15 + ldr q20, [x26, #0x0] + ldr q19, [x26, #0x10] + uzp1 v3.2d, v12.2d, v13.2d + uzp2 v2.2d, v12.2d, v13.2d + ld1 { v18.4s }, [x22] + ldr q29, [x26, #0x20] + uzp1 v13.2d, v10.2d, v0.2d + uzp2 v4.2d, v10.2d, v0.2d + ldr q21, [x26, #0x30] + uzp1 v28.2d, v15.2d, v14.2d + uzp2 v7.2d, v15.2d, v14.2d + add x22, x22, #0x10 + ldr q9, [x22, #0x0] + uzp1 v25.2d, v31.2d, v8.2d + uzp2 v24.2d, v31.2d, v8.2d + add x26, x26, #0x40 + mla v3.4s, v20.4s, v18.s[0] + mla v13.4s, v19.4s, v18.s[0] + mla v2.4s, v20.4s, v18.s[1] + mla v4.4s, v19.4s, v18.s[1] + mla v28.4s, v20.4s, v18.s[2] + mla v25.4s, v19.4s, v18.s[2] + fmul v23.4s, v29.4s, v9.s[0] + mla v7.4s, v20.4s, v18.s[3] + mla v24.4s, v19.4s, v18.s[3] + fmul v22.4s, v21.4s, v9.s[0] 
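+        // Row tail dequantization: v18 holds the four per-row LHS offsets (negative zero
+        // points) and v9 the per-row LHS scales from the packed LHS block, while v20/v19
+        // are the packed RHS column sums and v29/v21 the per-column RHS scales. The
+        // mla/scvtf/fmul sequence folds the offsets into the int32 accumulators, converts
+        // them to float and applies the combined LHS x RHS scale.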
+ scvtf v3.4s, v3.4s + scvtf v13.4s, v13.4s + fmul v30.4s, v29.4s, v9.s[1] + scvtf v2.4s, v2.4s + fmul v20.4s, v21.4s, v9.s[1] + scvtf v4.4s, v4.4s + fmul v19.4s, v29.4s, v9.s[2] + scvtf v28.4s, v28.4s + fmul v18.4s, v21.4s, v9.s[2] + scvtf v25.4s, v25.4s + fmul v27.4s, v29.4s, v9.s[3] + scvtf v7.4s, v7.4s + fmul v26.4s, v21.4s, v9.s[3] + scvtf v24.4s, v24.4s + fmul v12.4s, v3.4s, v23.4s + fmul v10.4s, v13.4s, v22.4s + fmul v13.4s, v2.4s, v30.4s + fmul v0.4s, v4.4s, v20.4s + fmul v15.4s, v28.4s, v19.4s + fmul v31.4s, v25.4s, v18.4s + fmul v14.4s, v7.4s, v27.4s + fmul v8.4s, v24.4s, v26.4s + ldr q19, [x26, #0x0] + ldr q18, [x26, #0x10] + add x20, x11, #0x4 + cmp x25, #0x8 + ld1r { v22.4s }, [x11] + ld1r { v30.4s }, [x20] + add x26, x26, #0x20 + fadd v12.4s, v12.4s, v19.4s + fadd v13.4s, v13.4s, v19.4s + fadd v15.4s, v15.4s, v19.4s + fadd v14.4s, v14.4s, v19.4s + fadd v10.4s, v10.4s, v18.4s + fadd v0.4s, v0.4s, v18.4s + fadd v31.4s, v31.4s, v18.4s + fadd v8.4s, v8.4s, v18.4s + fmax v12.4s, v12.4s, v22.4s + fmax v13.4s, v13.4s, v22.4s + fmax v15.4s, v15.4s, v22.4s + fmax v14.4s, v14.4s, v22.4s + fmax v10.4s, v10.4s, v22.4s + fmin v12.4s, v12.4s, v30.4s + fmax v0.4s, v0.4s, v22.4s + fmin v13.4s, v13.4s, v30.4s + fmin v15.4s, v15.4s, v30.4s + fmax v31.4s, v31.4s, v22.4s + fmin v14.4s, v14.4s, v30.4s + fmax v8.4s, v8.4s, v22.4s + fmin v10.4s, v10.4s, v30.4s + fmin v0.4s, v0.4s, v30.4s + KAI_ASM_INST(0x0ea16993) // bfcvtn v19.4h, v12.4s + fmin v31.4s, v31.4s, v30.4s + KAI_ASM_INST(0x0ea169b2) // bfcvtn v18.4h, v13.4s + fmin v8.4s, v8.4s, v30.4s + KAI_ASM_INST(0x0ea169e5) // bfcvtn v5.4h, v15.4s + KAI_ASM_INST(0x0ea169c3) // bfcvtn v3.4h, v14.4s + KAI_ASM_INST(0x4ea16953) // bfcvtn2 v19.8h, v10.4s + KAI_ASM_INST(0x4ea16812) // bfcvtn2 v18.8h, v0.4s + KAI_ASM_INST(0x4ea16be5) // bfcvtn2 v5.8h, v31.4s + KAI_ASM_INST(0x4ea16903) // bfcvtn2 v3.8h, v8.4s + blt label_17 + mov x20, x14 + cmp x13, #0x1 + str q19, [x20, #0x0] + add x20, x20, x12 + ble label_22 + cmp x13, #0x2 + str q18, [x20, #0x0] + add x20, x20, x12 + ble label_22 + cmp x13, #0x3 + str q5, [x20, #0x0] + add x20, x20, x12 + ble label_22 + str q3, [x20, #0x0] + b label_22 +KAI_ASM_LABEL(label_17) // Row tail: Partial output + mov x23, x14 + cmp x13, #0x1 + add x22, x23, x12 + csel x22, x22, x23, GT + cmp x13, #0x2 + add x21, x23, x12, LSL #1 + csel x21, x21, x22, GT + cmp x13, #0x3 + add x20, x21, x12 + csel x20, x20, x21, GT + tbz x25, #2, label_19 + st1 { v3.d }[0], [x20], #0x8 + st1 { v5.d }[0], [x21], #0x8 + st1 { v18.d }[0], [x22], #0x8 + st1 { v19.d }[0], [x23], #0x8 + tbz x25, #1, label_18 + st1 { v3.s }[2], [x20], #0x4 + st1 { v5.s }[2], [x21], #0x4 + st1 { v18.s }[2], [x22], #0x4 + st1 { v19.s }[2], [x23], #0x4 + tbz x25, #0, label_21 + st1 { v3.h }[6], [x20] + st1 { v5.h }[6], [x21] + st1 { v18.h }[6], [x22] + st1 { v19.h }[6], [x23] + b label_21 +KAI_ASM_LABEL(label_18) // Row tail: Output block 0: partial_1_4 + tbz x25, #0, label_21 + st1 { v3.h }[4], [x20] + st1 { v5.h }[4], [x21] + st1 { v18.h }[4], [x22] + st1 { v19.h }[4], [x23] + b label_21 +KAI_ASM_LABEL(label_19) // Row tail: Output block 0: partial_2_0 + tbz x25, #1, label_20 + st1 { v3.s }[0], [x20], #0x4 + st1 { v5.s }[0], [x21], #0x4 + st1 { v18.s }[0], [x22], #0x4 + st1 { v19.s }[0], [x23], #0x4 + tbz x25, #0, label_21 + st1 { v3.h }[2], [x20] + st1 { v5.h }[2], [x21] + st1 { v18.h }[2], [x22] + st1 { v19.h }[2], [x23] + b label_21 +KAI_ASM_LABEL(label_20) // Row tail: Output block 0: partial_1_0 + st1 { v3.h }[0], [x20] + st1 { v5.h }[0], [x21] + st1 { v18.h }[0], 
[x22] + st1 { v19.h }[0], [x23] +KAI_ASM_LABEL(label_21) // Row tail: Output block 0: Done +KAI_ASM_LABEL(label_22) // Row tail: Output stage exit + subs x25, x25, #0x8 + add x14, x14, #0x10 + bgt label_14 + subs x13, x13, #0x4 + add x17, x17, x7 + mov x14, x24 + bgt label_13 +KAI_ASM_LABEL(label_23) // Row tail: Row loop skip + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp d10, d11, [sp, 72] + ldp d12, d13, [sp, 88] + ldp d14, d15, [sp, 104] + ldp d8, d9, [sp, 120] + ldp x20, x21, [sp], 144 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h new file mode 100644 index 00000000..cf22a222 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h @@ -0,0 +1,52 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_bf16_qai8dxp_qsi4cxp + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel { + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_m_step_func_t get_m_step; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_n_step_func_t get_n_step; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_mr_func_t get_mr; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_nr_func_t get_nr; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_kr_func_t get_kr; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_sr_func_t get_sr; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_get_dst_size_func_t get_dst_size; + 
kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c new file mode 100644 index 00000000..175e81b3 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c @@ -0,0 +1,254 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if ( \ + !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || \ + !defined(__ARM_FEATURE_BF16_SCALAR_ARITHMETIC)) && \ + !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_BF16. +#else // Architectural features check. + +#include "kai_lhs_quant_pack_qai8dxp_bf16_neon.h" + +#include +#endif +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_per_multiplier = sizeof(float); +static const size_t kai_num_bytes_per_offset = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k) { + // Round up k to be a multiple of 32. + static const size_t kai_k_multiple_of = 32; + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_lhs_packed_stride(size_t k, size_t mr) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); +} + +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_UNUSED(kr); + KAI_UNUSED(sr); + // It always points to the beginning of the row + return (m_idx / mr) * kai_lhs_packed_stride(k, mr); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + KAI_UNUSED(kr); + KAI_UNUSED(sr); + const size_t num_rows = kai_roundup(m, mr) / mr; + + return num_rows * kai_lhs_packed_stride(k, mr); +} + +// Note: The lhs parameter type has been changed from float* to void*. +// The bfloat16 values (packed in 16 bits) will be converted to float32. +void kai_run_lhs_quant_pack_qai8dxp_bf16( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs, + size_t lhs_stride, void* restrict lhs_packed) { + KAI_ASSERT((kr % sr) == 0); + + if (m == 0) { + return; + } + + // Now lhs is assumed to contain bfloat16 values encoded in uint16_t. + const uint16_t* src_ptr = (uint16_t const*)lhs; + + const size_t dst_stride = kai_lhs_packed_stride(k, mr); + const size_t k_internal = kai_k_roundedup(k); + const int32_t k_block_len = (int32_t)(kr / sr); + + const int32_t num_blocks_k = (int32_t)(k / k_block_len); + const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + int32_t k_idx = 0; + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + uint16x8_t zero = vdupq_n_u16(0); + // Process 8 bfloat16 values per iteration. + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + // // Load eight bfloat16 values. 
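+            // A bfloat16 value is the upper 16 bits of an IEEE-754 float32, so the
+            // conversion below interleaves each bf16 element above a zero half-word
+            // (vzip1q_u16/vzip2q_u16 with a zero vector) and reinterprets the result
+            // as float32, widening eight bf16 values to two float32x4_t vectors without
+            // a dedicated conversion instruction.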
+ uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx); + uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); + uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); + float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); + float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); + + // Calculate the maximum + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + + // Calculate the minimum + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + } + // Get the max/min scalar values. + max0 = vmaxvq_f32(vmax0); + min0 = vminvq_f32(vmin0); + // Process leftover elements with a scalar loop. + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); + max0 = fmaxf(src0_0, max0); + min0 = fminf(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = fminf(0.0F, min0); + const float rmax0 = fmaxf(0.0F, max0); + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = fmaxf(zero_point0, qmin); + zero_point0 = fminf(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); + + // Quantize the channels + int32_t block_idx = 0; + + if (k_block_len == 8) { + for (; block_idx < num_blocks_k; ++block_idx) { + // Clamp at the last valid k-index + const int32_t k_idx_start = block_idx * k_block_len; + + // Load eight bfloat16 values and convert them to float32. + uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start); + uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); + uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); + float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); + float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); + + // Scale the values. + float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0); + float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0); + int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); + int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); + + int16x4_t v0_s16 = vqmovn_s32(v0_s32); + int16x4_t v1_s16 = vqmovn_s32(v1_s32); + int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); + + // Add zero points. + int16_t nzp_s16 = (int16_t)nudged_zero_point0; + int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); + v_s16 = vaddq_s16(v_s16, vnzp_s16); + v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); + v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); + + int8x8_t v0_s8 = vqmovn_s16(v_s16); + vst1_s8((int8_t*)(dst_ptr), v0_s8); + dst_ptr += 8 * sizeof(int8_t); + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + } else { + for (; block_idx < num_blocks_k; ++block_idx) { + for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { + const int32_t k_idx_start = (block_idx * k_block_len) + k_block_idx; + + // Convert the bfloat16 value to float. + const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); + + // Scale the value. 
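+                    // Asymmetric 8-bit quantization of a single element:
+                    //   q = clamp(round(x * scale) + zero_point, INT8_MIN, INT8_MAX)
+                    // using the per-row scale and nudged zero point computed above.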
+ int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + } + + for (; block_idx < num_blocks_k_internal; ++block_idx) { + // Left over k + for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index. + const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); + + const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); + + // Scale the value. + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * kai_num_bytes_per_offset; + + // LHS offset at the beginning of the row. + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + + dst_ptr += mr * kai_num_bytes_per_offset; + + // Store the scale quantization params. + *((float*)(dst_ptr)) = recip_scale0; + + // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). + src_ptr += (lhs_stride / sizeof(uint16_t)); + + // Move to the next row if we have interleaved all Mr rows. + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h new file mode 100644 index 00000000..e2760ca9 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h @@ -0,0 +1,76 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @param[in] mr The number of M rows to interleave on the same output row. +/// +/// @return the m step value +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16(size_t mr); + +/// Gets the offset in bytes for the LHS matrix (not packed) +/// +/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) +/// +/// @return the offset in bytes to the LHS matrix +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). 
+/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes for the quantized and packed LHS matrix +/// +/// @param[in] m Total number of rows in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed LHS matrix size in bytes +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Run the micro-kernel to quantize and pack the LHS matrix. +/// +/// @param[in] m The number of output rows written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be multiple of sr. +/// @param[in] m_idx_start The starting M index. +/// @param[in] lhs LHS of the vector-by-matrix. +/// @param[in] lhs_stride Stride in bytes between two rows of LHS. +/// @param[out] lhs_packed The quantized and packed LHS matrix. +void kai_run_lhs_quant_pack_qai8dxp_bf16( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} +#endif diff --git a/test/common/cpu_info.cpp b/test/common/cpu_info.cpp index 172a9937..9a2aa461 100644 --- a/test/common/cpu_info.cpp +++ b/test/common/cpu_info.cpp @@ -279,4 +279,12 @@ bool cpu_has_sme2() { return CpuInfo::current().has_sme2; } +bool cpu_has_dotprod_and_bf16() { + return cpu_has_dotprod() && cpu_has_bf16(); +} + +bool cpu_has_i8mm_and_bf16() { + return cpu_has_i8mm() && cpu_has_bf16(); +} + } // namespace kai::test diff --git a/test/common/cpu_info.hpp b/test/common/cpu_info.hpp index e9fe8601..1249e2eb 100644 --- a/test/common/cpu_info.hpp +++ b/test/common/cpu_info.hpp @@ -41,4 +41,10 @@ bool cpu_has_sme(); /// Returns a value indicating whether the current CPU supports FEAT_SME2. 
bool cpu_has_sme2(); +/// Returns a value indicating whether the current CPU supports FEAT_BF16 and FEAT_DotProd +bool cpu_has_dotprod_and_bf16(); + +/// Returns a value indicating whether the current CPU supports FEAT_BF16 and FEAT_I8MM +bool cpu_has_i8mm_and_bf16(); + } // namespace kai::test diff --git a/test/reference/cast.cpp b/test/reference/cast.cpp index d6728926..4c215b6e 100644 --- a/test/reference/cast.cpp +++ b/test/reference/cast.cpp @@ -33,6 +33,7 @@ Buffer cast(const void* src, size_t length) { template Buffer cast(const void* src, size_t length); template Buffer cast(const void* src, size_t length); template Buffer cast(const void* src, size_t length); +template Buffer cast(const void* src, size_t length); Buffer cast(const void* src, kai::test::DataType src_dt, DataType dst_dt, size_t height, size_t width) { const auto length = height * width; diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp index 179acc58..9057555b 100644 --- a/test/reference/fill.cpp +++ b/test/reference/fill.cpp @@ -125,5 +125,6 @@ Buffer fill_random(size_t length, uint32_t seed) { template Buffer fill_random(size_t length, uint32_t seed); template Buffer fill_random(size_t length, uint32_t seed); template Buffer fill_matrix_raw(size_t height, size_t width, std::function gen); +template Buffer fill_random(size_t length, uint32_t seed); } // namespace kai::test diff --git a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp new file mode 100644 index 00000000..b2bf5d64 --- /dev/null +++ b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp @@ -0,0 +1,358 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "test/common/bfloat16.hpp" +#include "test/common/buffer.hpp" +#include "test/common/compare.hpp" +#include "test/common/cpu_info.hpp" +#include "test/common/data_format.hpp" +#include "test/common/int4.hpp" +#include "test/common/matmul_test_common.hpp" +#include "test/common/matrix_portion.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/test_suite.hpp" +#include "test/reference/cast.hpp" +#include "test/reference/clamp.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/pad.hpp" +#include "test/reference/quantize.hpp" +#include "test/reference/transpose.hpp" + +namespace kai::test { + +static const std::array, 2> + variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp = {{ + {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod), + "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", cpu_has_dotprod_and_bf16}, + {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm), + 
"kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", cpu_has_i8mm_and_bf16}, + }}; + +class MatMulTest_bf16_qai8dxp_qsi4cxp : public ::testing::TestWithParam {}; + +TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_NxK) { + const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index); + + if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + const std::uint32_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + auto m_step = ukernel_variant.interface.get_m_step(); + ASSERT_TRUE(m_step % mr == 0); + + auto n_step = ukernel_variant.interface.get_n_step(); + ASSERT_TRUE(n_step % nr == 0); + + const auto rect = portion.compute_portion(M, N, m_step, n_step); + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; + } + + // Generates input data. + const auto ref_lhs_bf16 = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + + Buffer ref_biases_buf; + if (has_bias) { + ref_biases_buf = Buffer(fill_random(N, seed + 2)); + } + + // For reference implementation, Casting BF16 input to FP32 type and FP32 output back to BFP16 because the matmul + // implementation works with FP32 accumulation and casts the result to BFP16 + const auto ref_lhs = cast(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit symmetric quantization. + // * Quantizes the RHS matrix using 8-bit asymmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4, ref_rhs_scales] = + quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); + const auto ref_dst_no_clamp = + matmul_nt_t_quantized( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, + ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, K, has_bias ? ref_biases_buf.data() : nullptr, + nullptr, nullptr, 1); + + // Clamps the reference output. + const auto clamp_ratio = 0.8F; + const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_no_clamp.data(), M * N, clamp_ratio); + const auto ref_dst_float = clamp(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max); + + // Cast the reference output to BF16 + auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits); + + // Runs the LHS packing micro-kernel. 
+ const auto lhs_start_row = rect.start_row(); + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(M, K, mr, kr, sr); + Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); + + auto lhs_stride = K * sizeof(uint16_t); + + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, K, mr, kr, sr); + auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); + + ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); + + kai_run_lhs_quant_pack_qai8dxp_bf16( + rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, + reinterpret_cast(imp_packed_lhs_buf.data()) + lhs_packed_offset); + + const auto ref_rhs_qsi4_padded = pad_row( + ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); + + const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr); + Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size); + const auto rhs_start_row = rect.start_col(); + auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); + auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); + ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); + // Runs the RHS packing micro-kernel. + kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{}; + params.lhs_zero_point = 1; + params.rhs_zero_point = 0; + + kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( + 1, N, K, nr, kr, sr, reinterpret_cast(ref_rhs_qsi4_padded.data()), + has_bias ? reinterpret_cast(ref_biases_buf.data()) : nullptr, + reinterpret_cast(ref_rhs_scales.data()), reinterpret_cast(imp_packed_rhs_buf.data()), 0, + ¶ms); + + const auto dst_stride_row = N * sizeof(uint16_t); + const auto dst_stride_col = sizeof(uint16_t); + const auto dst_offset = + ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); + const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; + ASSERT_EQ(dst_offset, ref_dst_offset); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + Buffer imp_dst_buf = Buffer(imp_dst_size); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, reinterpret_cast(imp_packed_lhs_buf.data()) + lhs_matmul_offset, + reinterpret_cast(imp_packed_rhs_buf.data()) + rhs_matmul_offset, + reinterpret_cast(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, + clamp_max); + + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. 
+ DefaultMismatchHandler handler(0, 0.02, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::BF16); + const auto success = + compare(reinterpret_cast(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); +} + +TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_KxN) { + const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index); + + if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { + GTEST_SKIP() << "CPU features are not supported by current CPU"; + } + + const uint32_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + // Generates input data. + const auto ref_lhs_bf16 = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + Buffer ref_biases_buf; + if (has_bias) { + ref_biases_buf = Buffer(fill_random(N, seed + 2)); + } + + const auto ref_lhs = cast(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits); + + // Transposed(nxk) RHS dimensions + const size_t ref_rhs_qsi4_nxk_stride = K; + + // Non-Transposed(kxn) RHS dimensions + const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2); + const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit asymmetric quantization. + // * Quantizes the RHS matrix using 4-bit symmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block_dynamic(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] = + quantize_symmetric_per_block_dynamic(ref_rhs.data(), N, K, K); + + const auto ref_rhs_qsi4 = transpose_with_padding( + ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, + ref_rhs_qsi4_kxn_size_bytes); + + const auto ref_dst_fp32_clamp = + matmul_clamp_nt_nt( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), + ref_rhs_scales.data(), nullptr, K, has_bias ? ref_biases_buf.data() : nullptr, + std::numeric_limits::lowest(), std::numeric_limits::max()); + + // Clamps the reference output. + const auto clamp_ratio = 0.8F; + const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_fp32_clamp.data(), M * N, clamp_ratio); + const auto ref_dst_float = clamp(ref_dst_fp32_clamp.data(), M * N, clamp_min, clamp_max); + + // Cast the reference output to BF16 + auto ref_dst = cast(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits); + + auto m_step = ukernel_variant.interface.get_m_step(); + ASSERT_TRUE(m_step % mr == 0); + + auto n_step = ukernel_variant.interface.get_n_step(); + ASSERT_TRUE(n_step % nr == 0); + + const auto rect = portion.compute_portion(M, N, m_step, n_step); + if (rect.height() == 0 || rect.width() == 0) { + GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; + } + + const auto lhs_start_row = rect.start_row(); + size_t lhs_stride = K * sizeof(uint16_t); + + // Runs the LHS packing micro-kernel. 
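+    // Same LHS flow as in the NxK test above; only the RHS packing differs (kxn layout).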
+ const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(M, K, mr, kr, sr); + Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, K, mr, kr, sr); + auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); + ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); + + kai_run_lhs_quant_pack_qai8dxp_bf16( + rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, ref_lhs_bf16.data() + lhs_offset, lhs_stride, + reinterpret_cast(imp_packed_lhs_buf.data()) + lhs_packed_offset); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. + const auto ref_rhs_qsi4_padded = pad_row( + ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); + const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr); + + const auto rhs_start_row = rect.start_col(); + auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); + auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); + ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); + + Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size); + kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{}; + params.lhs_zero_point = 1; + params.rhs_zero_point = 0; + kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( + 1, N, K, nr, kr, sr, reinterpret_cast(ref_rhs_qsi4_padded.data()), + has_bias ? reinterpret_cast(ref_biases_buf.data()) : nullptr, + reinterpret_cast(ref_rhs_scales.data()), imp_packed_rhs_buf.data(), 0, ¶ms); + + const auto dst_stride = N * sizeof(uint16_t); + const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); + const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(uint16_t); + ASSERT_EQ(dst_offset, ref_dst_offset); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + Buffer imp_dst_buf = Buffer(imp_dst_size); + + const auto dst_stride_row = N * sizeof(uint16_t); + const auto dst_stride_col = sizeof(uint16_t); + + ukernel_variant.interface.run_matmul( + rect.height(), rect.width(), K, reinterpret_cast(imp_packed_lhs_buf.data()) + lhs_matmul_offset, + reinterpret_cast(imp_packed_rhs_buf.data()) + rhs_matmul_offset, + reinterpret_cast(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, + clamp_max); + + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. 
+ DefaultMismatchHandler handler(0, 0.02, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::BF16); + const auto success = + compare(reinterpret_cast(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest_bf16_qai8dxp_qsi4cxp, + testing::Combine( + testing::Range(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.size()), + testing::Values( + MatMulShape{1, 2, 32}, // + MatMulShape{1, 3, 32}, // + MatMulShape{1, 4, 32}, // + MatMulShape{1, 5, 32}, // + MatMulShape{3, 3, 32}, // + MatMulShape{4, 4, 32}, // + MatMulShape{5, 5, 32}, // + MatMulShape{32, 64, 64}, // + MatMulShape{16, 32, 64}, // + MatMulShape{8, 32, 64}, // + MatMulShape{15, 32, 32}, // + MatMulShape{77, 99, 64}), + testing::Values( + MatrixPortion(0, 0, 1, 1), // Full matrix. + MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. + MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. + MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle + MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. + MatrixPortion(0.75, 0, 1, 1), // Partial rows + MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle + ), + testing::Bool()), + [](const auto& info) -> std::string { + const auto variant_idx = std::get<0>(info.param); + const auto& name = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp[variant_idx].name; + return test_description( + name, std::get(info.param), std::get<2>(info.param), std::get<3>(info.param)); + }); + +} // namespace kai::test -- GitLab From 85c4d0545b53c9ffa165cb593170b8c75e2ce04b Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Tue, 1 Jul 2025 15:02:52 +0000 Subject: [PATCH 2/9] Update kai_lhs_quant_pack_qai8dxp_bf16_neon apis to have _neon suffix Signed-off-by: Nikhil Gupta --- .../pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c | 11 ++++++----- .../pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h | 11 ++++++----- .../matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp | 16 ++++++++-------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c index 175e81b3..faaf58b9 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c @@ -37,23 +37,24 @@ inline static size_t kai_lhs_packed_stride(size_t k, size_t mr) { return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); } -size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16(size_t mr) { +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16_neon(size_t mr) { KAI_UNUSED(mr); return 1; } -size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t lhs_stride) { +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(size_t m_idx, size_t lhs_stride) { return m_idx * lhs_stride; } -size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { KAI_UNUSED(kr); KAI_UNUSED(sr); // It always points to the beginning of the row return (m_idx / mr) * kai_lhs_packed_stride(k, mr); } -size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { KAI_UNUSED(kr); 
KAI_UNUSED(sr); const size_t num_rows = kai_roundup(m, mr) / mr; @@ -63,7 +64,7 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(size_t m, size_t k, s // Note: The lhs parameter type has been changed from float* to void*. // The bfloat16 values (packed in 16 bits) will be converted to float32. -void kai_run_lhs_quant_pack_qai8dxp_bf16( +void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs, size_t lhs_stride, void* restrict lhs_packed) { KAI_ASSERT((kr % sr) == 0); diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h index e2760ca9..5f59c602 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h @@ -18,7 +18,7 @@ extern "C" { /// @param[in] mr The number of M rows to interleave on the same output row. /// /// @return the m step value -size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16(size_t mr); +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16_neon(size_t mr); /// Gets the offset in bytes for the LHS matrix (not packed) /// @@ -28,7 +28,7 @@ size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16(size_t mr); /// @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) /// /// @return the offset in bytes to the LHS matrix -size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(size_t m_idx, size_t lhs_stride); /// Gets the offset in bytes for the packed LHS matrix, /// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. @@ -42,7 +42,8 @@ size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t lhs_s /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the offset in bytes to the packed LHS matrix -size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); /// Gets the size in bytes for the quantized and packed LHS matrix /// @@ -53,7 +54,7 @@ size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(size_t m_idx, size_ /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the packed LHS matrix size in bytes -size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(size_t m, size_t k, size_t mr, size_t kr, size_t sr); +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr); /// Run the micro-kernel to quantize and pack the LHS matrix. /// @@ -67,7 +68,7 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(size_t m, size_t k, s /// @param[in] lhs LHS of the vector-by-matrix. /// @param[in] lhs_stride Stride in bytes between two rows of LHS. /// @param[out] lhs_packed The quantized and packed LHS matrix. 
-void kai_run_lhs_quant_pack_qai8dxp_bf16( +void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed); diff --git a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp index b2bf5d64..9131fdb5 100644 --- a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp @@ -121,18 +121,18 @@ TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_NxK) { // Runs the LHS packing micro-kernel. const auto lhs_start_row = rect.start_row(); - const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(M, K, mr, kr, sr); + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); auto lhs_stride = K * sizeof(uint16_t); - auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, lhs_stride); - auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, K, mr, kr, sr); + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); - kai_run_lhs_quant_pack_qai8dxp_bf16( + kai_run_lhs_quant_pack_qai8dxp_bf16_neon( rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, reinterpret_cast(imp_packed_lhs_buf.data()) + lhs_packed_offset); @@ -261,14 +261,14 @@ TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_KxN) { size_t lhs_stride = K * sizeof(uint16_t); // Runs the LHS packing micro-kernel. 
- const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16(M, K, mr, kr, sr); + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); - auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, lhs_stride); - auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16(lhs_start_row, K, mr, kr, sr); + auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); + auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); - kai_run_lhs_quant_pack_qai8dxp_bf16( + kai_run_lhs_quant_pack_qai8dxp_bf16_neon( rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, ref_lhs_bf16.data() + lhs_offset, lhs_stride, reinterpret_cast(imp_packed_lhs_buf.data()) + lhs_packed_offset); -- GitLab From 7d9d8c62df91792f0f161bdabca2be716bb9c653 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Fri, 4 Jul 2025 17:00:54 +0100 Subject: [PATCH 3/9] address review comments Signed-off-by: Evie Wright --- CHANGELOG.md | 5 +- CMakeLists.txt | 33 +- kai/ukernels/matmul/BUILD.bazel | 30 +- ...6_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c | 7 +- ...bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c | 7 +- .../kai_lhs_quant_pack_qai8dxp_bf16_neon.c | 516 ++++++++++++++---- ...matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp | 3 +- 7 files changed, 421 insertions(+), 180 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03ebd525..a9cc0096 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release -- Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_I8MM. -- Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_DotProd. +- New Advanced SIMD micro-kernels: + - Matrix multiplication (MxN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_I8MM. + - Matrix multiplication (1xN) Micro-kernels of QAI8DX LHS and QSI4CX RHS with BF16 output, optimized for FEAT_DotProd. 
## v1.11.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index fdc4183a..1656e376 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -139,12 +139,7 @@ set(KLEIDIAI_FILES_NEON_FP16_I8MM ${KLEIDIAI_FILES_NEON_FP16_I8MM_ASM} ) -set(KLEIDIAI_FILES_NEON_BF16_ASM - kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c -) - set(KLEIDIAI_FILES_NEON_BF16 - ${KLEIDIAI_FILES_NEON_BF16_ASM} kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.c kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c @@ -164,6 +159,7 @@ set(KLEIDIAI_FILES_NEON_ASM kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c + kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c ) set(KLEIDIAI_FILES_NEON @@ -199,6 +195,8 @@ set(KLEIDIAI_FILES_NEON_DOTPROD_ASM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod_asm.S kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod_asm.S + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S ) set(KLEIDIAI_FILES_NEON_DOTPROD @@ -227,6 +225,8 @@ set(KLEIDIAI_FILES_NEON_I8MM_ASM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S ) set(KLEIDIAI_FILES_NEON_I8MM @@ -317,24 +317,6 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ) -set(KLEIDIAI_FILES_NEON_BF16_DOTPROD_ASM - kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c - kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod_asm.S -) - -set(KLEIDIAI_FILES_NEON_BF16_I8MM_ASM - kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c - kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm_asm.S -) - -set(KLEIDIAI_FILES_NEON_BF16_DOTPROD - ${KLEIDIAI_FILES_NEON_BF16_DOTPROD_ASM} -) - -set(KLEIDIAI_FILES_NEON_BF16_I8MM - ${KLEIDIAI_FILES_NEON_BF16_I8MM_ASM} -) - add_library(kleidiai) add_library(${PROJECT_NAME}::kleidiai ALIAS kleidiai) @@ -352,8 +334,6 @@ if(NOT MSVC) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2}) - 
target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_DOTPROD}) - target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_I8MM}) set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) @@ -365,9 +345,6 @@ if(NOT MSVC) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16_BF16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - # Use -fno-tree-vectorize option to disable compiler based vectorization set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") set_source_files_properties(${KLEIDIAI_FILES_SME2} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index b961b9e7..5535a57e 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -30,6 +30,7 @@ SCALAR_KERNELS = [ # buildifier: keep sorted NEON_KERNELS = [ + "pack/kai_lhs_quant_pack_qai8dxp_bf16_neon", "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon", "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon", "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon", @@ -62,7 +63,6 @@ BF16_KERNELS = [ "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla", "pack/kai_lhs_quant_pack_bf16p1x4_f32_neon", "pack/kai_lhs_quant_pack_bf16p8x4_f32_neon", - "pack/kai_lhs_quant_pack_qai8dxp_bf16_neon", "pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon", ] @@ -111,6 +111,7 @@ DOTPROD_KERNELS = [ # buildifier: keep sorted DOTPROD_KERNELS_ASM = [ + "matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", @@ -136,6 +137,7 @@ I8MM_KERNELS = [ # buildifier: keep sorted I8MM_KERNELS_ASM = [ + "matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm", @@ -143,16 +145,6 @@ I8MM_KERNELS_ASM = [ "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", ] -# buildifier: keep sorted -DOTPROD_BF16_KERNELS_ASM = [ - "matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", -] - -# buildifier: keep sorted -I8MM_BF16_KERNELS_ASM = [ - 
"matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", -] - # buildifier: keep sorted SME_KERNELS = [ "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", @@ -322,33 +314,17 @@ kai_c_library( textual_hdrs = [ukernel + ".h" for ukernel in SME2_KERNELS_ASM], ) -kai_c_library( - name = "i8mm_bf16_impl_asm", - srcs = [ukernel + "_asm.S" for ukernel in I8MM_BF16_KERNELS_ASM] + [ukernel + ".c" for ukernel in I8MM_BF16_KERNELS_ASM], - cpu_uarch = kai_cpu_i8mm() + kai_cpu_bf16(), - textual_hdrs = [ukernel + ".h" for ukernel in I8MM_BF16_KERNELS_ASM], -) - -kai_c_library( - name = "dotprod_bf16_impl_asm", - srcs = [ukernel + "_asm.S" for ukernel in DOTPROD_BF16_KERNELS_ASM] + [ukernel + ".c" for ukernel in DOTPROD_BF16_KERNELS_ASM], - cpu_uarch = kai_cpu_dotprod() + kai_cpu_bf16(), - textual_hdrs = [ukernel + ".h" for ukernel in DOTPROD_BF16_KERNELS_ASM], -) - kai_c_library( name = "matmul", visibility = ["//visibility:public"], deps = [ ":bf16_impl", - ":dotprod_bf16_impl_asm", ":dotprod_impl", ":dotprod_impl_asm", ":fp16_bf16_impl", ":fp16_dotprod_impl_asm", ":fp16_i8mm_impl_asm", ":fp16_impl", - ":i8mm_bf16_impl_asm", ":i8mm_impl", ":i8mm_impl_asm", ":interface", diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c index 8b356de7..6c295abb 100644 --- a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c @@ -3,9 +3,8 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ - !defined(_M_ARM64) -#error "Dotprod extension and bf16 vector arithmetic required to compile this micro-kernel" +#if (!defined(__aarch64__) && !defined(_M_ARM64)) || !defined(__ARM_FEATURE_DOTPROD) +#error "Dotprod extension required to compile this micro-kernel" #else // Architectural features check. #include "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h" @@ -16,7 +15,7 @@ #include "kai/kai_common.h" typedef struct { - uint16_t* dst; + void* dst; const void* lhs_packed; const void* rhs_packed; const float* clamp_vals; diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c index d4ac78fb..1072131f 100644 --- a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.c @@ -3,9 +3,8 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ - !defined(_M_ARM64) -#error "I8mm extension and bf16 vector arithmetic required to compile this micro-kernel" +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64) +#error "I8mm extension required to compile this micro-kernel" #else // Architectural features check. 
#include "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h" @@ -16,7 +15,7 @@ #include "kai/kai_common.h" typedef struct { - uint16_t* dst; + void* dst; const void* lhs_packed; const void* rhs_packed; const float* clamp_vals; diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c index faaf58b9..20ff6424 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c @@ -3,11 +3,8 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if ( \ - !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || \ - !defined(__ARM_FEATURE_BF16_SCALAR_ARITHMETIC)) && \ - !defined(_M_ARM64) -#error This file must be compiled for AArch64, FEAT_BF16. +#if (!defined(__aarch64__) && !defined(_M_ARM64)) +#error This file must be compiled for AArch64. #else // Architectural features check. #include "kai_lhs_quant_pack_qai8dxp_bf16_neon.h" @@ -16,6 +13,7 @@ #endif #include #include +#include #include #include "kai/kai_common.h" @@ -79,99 +77,409 @@ void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( const size_t dst_stride = kai_lhs_packed_stride(k, mr); const size_t k_internal = kai_k_roundedup(k); const int32_t k_block_len = (int32_t)(kr / sr); + KAI_ASSERT(k_block_len == 8); const int32_t num_blocks_k = (int32_t)(k / k_block_len); const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); - for (size_t row_idx = 0; row_idx < m; ++row_idx) { - float max0 = -FLT_MAX; - float min0 = FLT_MAX; - - // Find min/max for each channel - int32_t k_idx = 0; - float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); - float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); - uint16x8_t zero = vdupq_n_u16(0); - // Process 8 bfloat16 values per iteration. - for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { - // // Load eight bfloat16 values. - uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx); - uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); - uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); - float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); - float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); - - // Calculate the maximum - vmax0 = vmaxq_f32(src0_0, vmax0); - vmax0 = vmaxq_f32(vmax0, src0_1); - - // Calculate the minimum - vmin0 = vminq_f32(src0_0, vmin0); - vmin0 = vminq_f32(vmin0, src0_1); - } - // Get the max/min scalar values. - max0 = vmaxvq_f32(vmax0); - min0 = vminvq_f32(vmin0); - // Process leftover elements with a scalar loop. - for (; k_idx < (int32_t)k; ++k_idx) { - const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); - max0 = fmaxf(src0_0, max0); - min0 = fminf(src0_0, min0); + if (mr == 4) { + for (size_t row_idx = 0; row_idx < m; row_idx += 4) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + float max1 = -FLT_MAX; + float min1 = FLT_MAX; + float max2 = -FLT_MAX; + float min2 = FLT_MAX; + float max3 = -FLT_MAX; + float min3 = FLT_MAX; + + // Find min/max for each channel + int32_t k_idx = 0; + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + float32x4_t vmax1 = vmax0; + float32x4_t vmin1 = vmin0; + float32x4_t vmax2 = vmax0; + float32x4_t vmin2 = vmin0; + float32x4_t vmax3 = vmax0; + float32x4_t vmin3 = vmin0; + const uint16x8_t zero = vdupq_n_u16(0); + // Process 8 bfloat16 values per iteration. + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + // Load eight bfloat16 values. 
+ const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx); + const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t))); + const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t)))); + const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t)))); + + const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0); + const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0); + const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1); + const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1); + const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2); + const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2); + const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3); + const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3); + + const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0); + const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0); + const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1); + const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1); + const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2); + const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2); + const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3); + const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3); + + // Calculate the maximum + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + vmax1 = vmaxq_f32(src1_0, vmax1); + vmax1 = vmaxq_f32(vmax1, src1_1); + vmax2 = vmaxq_f32(src2_0, vmax2); + vmax2 = vmaxq_f32(vmax2, src2_1); + vmax3 = vmaxq_f32(src3_0, vmax3); + vmax3 = vmaxq_f32(vmax3, src3_1); + + // Calculate the minimum + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + vmin1 = vminq_f32(src1_0, vmin1); + vmin1 = vminq_f32(vmin1, src1_1); + vmin2 = vminq_f32(src2_0, vmin2); + vmin2 = vminq_f32(vmin2, src2_1); + vmin3 = vminq_f32(src3_0, vmin3); + vmin3 = vminq_f32(vmin3, src3_1); + } + // Get the max/min scalar values. + max0 = vmaxvq_f32(vmax0); + min0 = vminvq_f32(vmin0); + max1 = vmaxvq_f32(vmax1); + min1 = vminvq_f32(vmin1); + max2 = vmaxvq_f32(vmax2); + min2 = vminvq_f32(vmin2); + max3 = vmaxvq_f32(vmax3); + min3 = vminvq_f32(vmin3); + // Process leftover elements with a scalar loop. + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); + max0 = fmaxf(src0, max0); + min0 = fminf(src0, min0); + const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t)))); + max1 = fmaxf(src1, max1); + min1 = fminf(src1, min1); + const float src2 = kai_cast_f32_bf16(*(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t))))); + max2 = fmaxf(src2, max2); + min2 = fminf(src2, min2); + const float src3 = kai_cast_f32_bf16(*(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t))))); + max3 = fmaxf(src3, max3); + min3 = fminf(src3, min3); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = fminf(0.0F, min0); + const float rmax0 = fmaxf(0.0F, max0); + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); + const float rmin1 = fminf(0.0F, min1); + const float rmax1 = fmaxf(0.0F, max1); + const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1); + const float rmin2 = fminf(0.0F, min2); + const float rmax2 = fmaxf(0.0F, max2); + const float scale2 = rmin2 == rmax2 ? 
1.F : (qmax - qmin) / (rmax2 - rmin2); + const float rmin3 = fminf(0.0F, min3); + const float rmax3 = fmaxf(0.0F, max3); + const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F; + const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F; + const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + const float descaled_min1 = rmin1 * scale1; + const float descaled_max1 = rmax1 * scale1; + const float descaled_min2 = rmin2 * scale2; + const float descaled_max2 = rmax2 * scale2; + const float descaled_min3 = rmin3 * scale3; + const float descaled_max3 = rmax3 * scale3; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + const float zero_point_from_min_error1 = qmin + descaled_min1; + const float zero_point_from_max_error1 = qmax + descaled_max1; + const float zero_point_from_min_error2 = qmin + descaled_min2; + const float zero_point_from_max_error2 = qmax + descaled_max2; + const float zero_point_from_min_error3 = qmin + descaled_min3; + const float zero_point_from_max_error3 = qmax + descaled_max3; + + float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 + : qmax - descaled_max0; + float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1 + : qmax - descaled_max1; + float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2 + : qmax - descaled_max2; + float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3 + : qmax - descaled_max3; + + zero_point0 = fmaxf(zero_point0, qmin); + zero_point0 = fminf(zero_point0, qmax); + zero_point1 = fmaxf(zero_point1, qmin); + zero_point1 = fminf(zero_point1, qmax); + zero_point2 = fmaxf(zero_point2, qmin); + zero_point2 = fminf(zero_point2, qmax); + zero_point3 = fmaxf(zero_point3, qmin); + zero_point3 = fminf(zero_point3, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1); + const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2); + const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3); + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len); + + // Quantize the channels + int32_t block_idx = 0; + + for (; block_idx < num_blocks_k; ++block_idx) { + // Clamp at the last valid k-index + const int32_t k_idx_start = block_idx * k_block_len; + + // Load eight bfloat16 values and convert them to float32. 
+ const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx_start); + const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t))); + const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t)))); + const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t)))); + const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0); + const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0); + const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1); + const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1); + const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2); + const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2); + const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3); + const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3); + const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0); + const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0); + const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1); + const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1); + const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2); + const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2); + const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3); + const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3); + + // Scale the values. + const int16x4_t v0_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_0, scale0))); + const int16x4_t v1_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_1, scale0))); + const int16x4_t v0_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_0, scale1))); + const int16x4_t v1_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_1, scale1))); + const int16x4_t v0_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_0, scale2))); + const int16x4_t v1_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_1, scale2))); + const int16x4_t v0_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_0, scale3))); + const int16x4_t v1_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_1, scale3))); + + int16x8_t v0_s16 = vcombine_s16(v0_0, v1_0); + int16x8_t v1_s16 = vcombine_s16(v0_1, v1_1); + int16x8_t v2_s16 = vcombine_s16(v0_2, v1_2); + int16x8_t v3_s16 = vcombine_s16(v0_3, v1_3); + + // Add zero points. 
+ const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0); + const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1); + const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2); + const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3); + + v0_s16 = vaddq_s16(v0_s16, vnzp0); + v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN)); + v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX)); + v1_s16 = vaddq_s16(v1_s16, vnzp1); + v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN)); + v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX)); + v2_s16 = vaddq_s16(v2_s16, vnzp2); + v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN)); + v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX)); + v3_s16 = vaddq_s16(v3_s16, vnzp3); + v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN)); + v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX)); + + const uint8x8_t v0_s8 = vqmovn_s16(v0_s16); + const uint8x8_t v1_s8 = vqmovn_s16(v1_s16); + const uint8x8_t v2_s8 = vqmovn_s16(v2_s16); + const uint8x8_t v3_s8 = vqmovn_s16(v3_s16); + + vst1_s8((int8_t*)(dst_ptr), v0_s8); + vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8); + vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8); + vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8); + dst_ptr += 4 * sizeof(int8x8_t); + } + + for (; block_idx < num_blocks_k_internal; ++block_idx) { + // Left over k + for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index. + const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); + + const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); + const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t)))); + const float src2 = + kai_cast_f32_bf16(*(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t))))); + const float src3 = + kai_cast_f32_bf16(*(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t))))); + + // Scale the value. + int32_t v0_s32 = (int32_t)(roundf(src0 * scale0)); + int32_t v1_s32 = (int32_t)(roundf(src1 * scale1)); + int32_t v2_s32 = (int32_t)(roundf(src2 * scale2)); + int32_t v3_s32 = (int32_t)(roundf(src3 * scale3)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + + v1_s32 = v1_s32 + nudged_zero_point1; + v1_s32 = KAI_MAX(v1_s32, INT8_MIN); + v1_s32 = KAI_MIN(v1_s32, INT8_MAX); + + v2_s32 = v2_s32 + nudged_zero_point2; + v2_s32 = KAI_MAX(v2_s32, INT8_MIN); + v2_s32 = KAI_MIN(v2_s32, INT8_MAX); + + v3_s32 = v3_s32 + nudged_zero_point3; + v3_s32 = KAI_MAX(v3_s32, INT8_MIN); + v3_s32 = KAI_MIN(v3_s32, INT8_MAX); + + *(int8_t*)dst_ptr = (int8_t)v0_s32; + *(int8_t*)(dst_ptr + sizeof(int8x8_t)) = (int8_t)v1_s32; + *(int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)) = (int8_t)v2_s32; + *(int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)) = (int8_t)v3_s32; + + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset; + + // LHS offset at the beginning of the row. + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1; + *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2; + *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3; + + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. 
+ KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + + dst_ptr += mr * kai_num_bytes_per_offset; + + // Store the scale quantization params. + *((float*)(dst_ptr)) = recip_scale0; + *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1; + *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2; + *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3; + + // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). + src_ptr += (4 * lhs_stride / sizeof(uint16_t)); + + // Move to the next row as we have interleaved all Mr rows. + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); } + } else { + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + int32_t k_idx = 0; + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + const uint16x8_t zero = vdupq_n_u16(0); + // Process 8 bfloat16 values per iteration. + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + // Load eight bfloat16 values. + const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx); + const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); + const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); + const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); + const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); + + // Calculate the maximum + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + + // Calculate the minimum + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + } + // Get the max/min scalar values. + max0 = vmaxvq_f32(vmax0); + min0 = vminvq_f32(vmin0); + // Process leftover elements with a scalar loop. + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); + max0 = fmaxf(src0_0, max0); + min0 = fminf(src0_0, min0); + } - // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; - const float rmin0 = fminf(0.0F, min0); - const float rmax0 = fmaxf(0.0F, max0); - const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); + const float rmin0 = fminf(0.0F, min0); + const float rmax0 = fmaxf(0.0F, max0); + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); - // Reciprocal to quantize - const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; - const float descaled_min0 = rmin0 * scale0; - const float descaled_max0 = rmax0 * scale0; + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; - const float zero_point_from_min_error0 = qmin + descaled_min0; - const float zero_point_from_max_error0 = qmax + descaled_max0; + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; - float zero_point0 = - (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 : qmax - descaled_max0; + float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? 
qmin - descaled_min0 + : qmax - descaled_max0; - zero_point0 = fmaxf(zero_point0, qmin); - zero_point0 = fminf(zero_point0, qmax); + zero_point0 = fmaxf(zero_point0, qmin); + zero_point0 = fminf(zero_point0, qmax); - // Round to nearest integer - const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); - const size_t dst_x = ((row_idx + m_idx_start) % mr); + const size_t dst_x = ((row_idx + m_idx_start) % mr); - uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); + uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); - // Quantize the channels - int32_t block_idx = 0; + // Quantize the channels + int32_t block_idx = 0; - if (k_block_len == 8) { for (; block_idx < num_blocks_k; ++block_idx) { // Clamp at the last valid k-index const int32_t k_idx_start = block_idx * k_block_len; // Load eight bfloat16 values and convert them to float32. - uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start); - uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); - uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); - float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); - float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); + const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start); + const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); + const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); + const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); + const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); // Scale the values. - float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0); - float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0); - int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); - int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); + const float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0); + const float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0); + const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); + const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); - int16x4_t v0_s16 = vqmovn_s32(v0_s32); - int16x4_t v1_s16 = vqmovn_s32(v1_s32); + const int16x4_t v0_s16 = vqmovn_s32(v0_s32); + const int16x4_t v1_s16 = vqmovn_s32(v1_s32); int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); // Add zero points. @@ -181,17 +489,18 @@ void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); - int8x8_t v0_s8 = vqmovn_s16(v_s16); + const int8x8_t v0_s8 = vqmovn_s16(v_s16); vst1_s8((int8_t*)(dst_ptr), v0_s8); dst_ptr += 8 * sizeof(int8_t); dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); } - } else { - for (; block_idx < num_blocks_k; ++block_idx) { + + for (; block_idx < num_blocks_k_internal; ++block_idx) { + // Left over k for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { - const int32_t k_idx_start = (block_idx * k_block_len) + k_block_idx; + // Clamp at the last valid k-index. + const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); - // Convert the bfloat16 value to float. const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); // Scale the value. @@ -206,50 +515,29 @@ void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( } dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); } - } - - for (; block_idx < num_blocks_k_internal; ++block_idx) { - // Left over k - for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { - // Clamp at the last valid k-index. 
- const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); - - const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); - - // Scale the value. - int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); - - v0_s32 = v0_s32 + nudged_zero_point0; - v0_s32 = KAI_MAX(v0_s32, INT8_MIN); - v0_s32 = KAI_MIN(v0_s32, INT8_MAX); - - *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; - dst_ptr += sizeof(int8_t); - } - dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); - } - dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); - dst_ptr += dst_x * kai_num_bytes_per_offset; + dst_ptr += dst_x * kai_num_bytes_per_offset; - // LHS offset at the beginning of the row. - *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + // LHS offset at the beginning of the row. + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; - // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. - KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); - dst_ptr += mr * kai_num_bytes_per_offset; + dst_ptr += mr * kai_num_bytes_per_offset; - // Store the scale quantization params. - *((float*)(dst_ptr)) = recip_scale0; + // Store the scale quantization params. + *((float*)(dst_ptr)) = recip_scale0; - // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). - src_ptr += (lhs_stride / sizeof(uint16_t)); + // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). + src_ptr += (lhs_stride / sizeof(uint16_t)); - // Move to the next row if we have interleaved all Mr rows. - if ((((row_idx + 1) + m_idx_start) % mr) == 0) { - lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + // Move to the next row if we have interleaved all Mr rows. + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } } } } diff --git a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp index 9131fdb5..dd542f43 100644 --- a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp @@ -337,7 +337,8 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{16, 32, 64}, // MatMulShape{8, 32, 64}, // MatMulShape{15, 32, 32}, // - MatMulShape{77, 99, 64}), + MatMulShape{77, 99, 64}, // + MatMulShape{77, 99, 66}), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. 
-- GitLab From d84dfbe9bd174fb17f86d6792cfe87c974f13309 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Fri, 4 Jul 2025 17:24:29 +0100 Subject: [PATCH 4/9] change signed vectors to correct type Signed-off-by: Evie Wright --- .../matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c index 20ff6424..0d57581d 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c @@ -307,10 +307,10 @@ void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN)); v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX)); - const uint8x8_t v0_s8 = vqmovn_s16(v0_s16); - const uint8x8_t v1_s8 = vqmovn_s16(v1_s16); - const uint8x8_t v2_s8 = vqmovn_s16(v2_s16); - const uint8x8_t v3_s8 = vqmovn_s16(v3_s16); + const int8x8_t v0_s8 = vqmovn_s16(v0_s16); + const int8x8_t v1_s8 = vqmovn_s16(v1_s16); + const int8x8_t v2_s8 = vqmovn_s16(v2_s16); + const int8x8_t v3_s8 = vqmovn_s16(v3_s16); vst1_s8((int8_t*)(dst_ptr), v0_s8); vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8); -- GitLab From 45976a771a5ad369b2e9bc74edbe7f9e595edb2a Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Fri, 4 Jul 2025 17:43:23 +0100 Subject: [PATCH 5/9] update macro guard to avoid remote compilation error Signed-off-by: Evie Wright --- ...i_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c index 6c295abb..3bc380f2 100644 --- a/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if (!defined(__aarch64__) && !defined(_M_ARM64)) || !defined(__ARM_FEATURE_DOTPROD) +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64) #error "Dotprod extension required to compile this micro-kernel" #else // Architectural features check. 
-- GitLab From 662630b29b4c7926bafd8841b2c4f98020a4f86f Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 7 Jul 2025 11:14:00 +0100 Subject: [PATCH 6/9] ensure support for cases where m is not a multiple of mr Signed-off-by: Evie Wright --- .../kai_lhs_quant_pack_qai8dxp_bf16_neon.c | 234 +++++++++--------- 1 file changed, 118 insertions(+), 116 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c index 0d57581d..8c1199dd 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c @@ -82,8 +82,10 @@ void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( const int32_t num_blocks_k = (int32_t)(k / k_block_len); const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); + size_t row_idx = 0; + if (mr == 4) { - for (size_t row_idx = 0; row_idx < m; row_idx += 4) { + for (; row_idx + 3 < m; row_idx += 4) { float max0 = -FLT_MAX; float min0 = FLT_MAX; float max1 = -FLT_MAX; @@ -391,153 +393,153 @@ void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( // Move to the next row as we have interleaved all Mr rows. lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); } - } else { - for (size_t row_idx = 0; row_idx < m; ++row_idx) { - float max0 = -FLT_MAX; - float min0 = FLT_MAX; + } - // Find min/max for each channel - int32_t k_idx = 0; - float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); - float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); - const uint16x8_t zero = vdupq_n_u16(0); - // Process 8 bfloat16 values per iteration. - for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { - // Load eight bfloat16 values. - const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx); - const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); - const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); - const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); - const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); + for (; row_idx < m; ++row_idx) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + int32_t k_idx = 0; + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + const uint16x8_t zero = vdupq_n_u16(0); + // Process 8 bfloat16 values per iteration. + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + // Load eight bfloat16 values. + const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx); + const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); + const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); + const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); + const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); + + // Calculate the maximum + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + + // Calculate the minimum + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + } + // Get the max/min scalar values. + max0 = vmaxvq_f32(vmax0); + min0 = vminvq_f32(vmin0); + // Process leftover elements with a scalar loop. 
+ for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); + max0 = fmaxf(src0_0, max0); + min0 = fminf(src0_0, min0); + } - // Calculate the maximum - vmax0 = vmaxq_f32(src0_0, vmax0); - vmax0 = vmaxq_f32(vmax0, src0_1); + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; - // Calculate the minimum - vmin0 = vminq_f32(src0_0, vmin0); - vmin0 = vminq_f32(vmin0, src0_1); - } - // Get the max/min scalar values. - max0 = vmaxvq_f32(vmax0); - min0 = vminvq_f32(vmin0); - // Process leftover elements with a scalar loop. - for (; k_idx < (int32_t)k; ++k_idx) { - const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); - max0 = fmaxf(src0_0, max0); - min0 = fminf(src0_0, min0); - } + const float rmin0 = fminf(0.0F, min0); + const float rmax0 = fmaxf(0.0F, max0); + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); - // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; - const float rmin0 = fminf(0.0F, min0); - const float rmax0 = fmaxf(0.0F, max0); - const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; - // Reciprocal to quantize - const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; - const float descaled_min0 = rmin0 * scale0; - const float descaled_max0 = rmax0 * scale0; + float zero_point0 = + (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 : qmax - descaled_max0; - const float zero_point_from_min_error0 = qmin + descaled_min0; - const float zero_point_from_max_error0 = qmax + descaled_max0; + zero_point0 = fmaxf(zero_point0, qmin); + zero_point0 = fminf(zero_point0, qmax); - float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 - : qmax - descaled_max0; + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); - zero_point0 = fmaxf(zero_point0, qmin); - zero_point0 = fminf(zero_point0, qmax); + const size_t dst_x = ((row_idx + m_idx_start) % mr); - // Round to nearest integer - const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); - const size_t dst_x = ((row_idx + m_idx_start) % mr); + // Quantize the channels + int32_t block_idx = 0; - uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); + for (; block_idx < num_blocks_k; ++block_idx) { + // Clamp at the last valid k-index + const int32_t k_idx_start = block_idx * k_block_len; - // Quantize the channels - int32_t block_idx = 0; + // Load eight bfloat16 values and convert them to float32. + const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start); + const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); + const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); + const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); + const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); - for (; block_idx < num_blocks_k; ++block_idx) { - // Clamp at the last valid k-index - const int32_t k_idx_start = block_idx * k_block_len; + // Scale the values. 
+ const float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0); + const float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0); + const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); + const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); - // Load eight bfloat16 values and convert them to float32. - const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start); - const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); - const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); - const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); - const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); + const int16x4_t v0_s16 = vqmovn_s32(v0_s32); + const int16x4_t v1_s16 = vqmovn_s32(v1_s32); + int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); - // Scale the values. - const float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0); - const float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0); - const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); - const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); + // Add zero points. + int16_t nzp_s16 = (int16_t)nudged_zero_point0; + int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); + v_s16 = vaddq_s16(v_s16, vnzp_s16); + v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); + v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); - const int16x4_t v0_s16 = vqmovn_s32(v0_s32); - const int16x4_t v1_s16 = vqmovn_s32(v1_s32); - int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); + const int8x8_t v0_s8 = vqmovn_s16(v_s16); + vst1_s8((int8_t*)(dst_ptr), v0_s8); + dst_ptr += 8 * sizeof(int8_t); + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } - // Add zero points. - int16_t nzp_s16 = (int16_t)nudged_zero_point0; - int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); - v_s16 = vaddq_s16(v_s16, vnzp_s16); - v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); - v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); + for (; block_idx < num_blocks_k_internal; ++block_idx) { + // Left over k + for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index. + const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); - const int8x8_t v0_s8 = vqmovn_s16(v_s16); - vst1_s8((int8_t*)(dst_ptr), v0_s8); - dst_ptr += 8 * sizeof(int8_t); - dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); - } + const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); - for (; block_idx < num_blocks_k_internal; ++block_idx) { - // Left over k - for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { - // Clamp at the last valid k-index. - const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); + // Scale the value. + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); - const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); - // Scale the value. 
- int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); - - v0_s32 = v0_s32 + nudged_zero_point0; - v0_s32 = KAI_MAX(v0_s32, INT8_MIN); - v0_s32 = KAI_MIN(v0_s32, INT8_MAX); - - *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; - dst_ptr += sizeof(int8_t); - } - dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } - dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); - dst_ptr += dst_x * kai_num_bytes_per_offset; + dst_ptr += dst_x * kai_num_bytes_per_offset; - // LHS offset at the beginning of the row. - *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + // LHS offset at the beginning of the row. + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; - // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. - KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); - dst_ptr += mr * kai_num_bytes_per_offset; + dst_ptr += mr * kai_num_bytes_per_offset; - // Store the scale quantization params. - *((float*)(dst_ptr)) = recip_scale0; + // Store the scale quantization params. + *((float*)(dst_ptr)) = recip_scale0; - // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). - src_ptr += (lhs_stride / sizeof(uint16_t)); + // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). + src_ptr += (lhs_stride / sizeof(uint16_t)); - // Move to the next row if we have interleaved all Mr rows. - if ((((row_idx + 1) + m_idx_start) % mr) == 0) { - lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); - } + // Move to the next row if we have interleaved all Mr rows. 
+ if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); } } } -- GitLab From c248525f85c676b61478246a9f19836761844c0b Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Mon, 7 Jul 2025 15:06:02 +0100 Subject: [PATCH 7/9] remove redundant cmakelists Signed-off-by: Evie Wright --- CMakeLists.txt | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1656e376..ad3dfa1d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -354,9 +354,6 @@ else() target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME_ASM}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2_ASM}) - target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_DOTPROD}) - target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_I8MM}) - target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16_ASM}) set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS /arch:armv8.0${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) @@ -364,18 +361,13 @@ else() set_source_files_properties(${KLEIDIAI_FILES_NEON_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_SME_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_SME2_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_DOTPROD} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_I8MM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) - set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16_ASM} PROPERTIES COMPILE_OPTIONS /arch:armv8.2${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set(KLEIDIAI_FILES_ASM ${KLEIDIAI_FILES_SME_ASM} ${KLEIDIAI_FILES_SME2_ASM} ${KLEIDIAI_FILES_NEON_ASM} ${KLEIDIAI_FILES_NEON_DOTPROD_ASM} - ${KLEIDIAI_FILES_NEON_I8MM_ASM} - ${KLEIDIAI_FILES_NEON_BF16_DOTPROD_ASM} - ${KLEIDIAI_FILES_NEON_BF16_I8MM_ASM}) + ${KLEIDIAI_FILES_NEON_I8MM_ASM}) list(FILTER KLEIDIAI_FILES_ASM INCLUDE REGEX "^.*\.S$") set_source_files_properties(${KLEIDIAI_FILES_ASM} PROPERTIES LANGUAGE ASM_MARMASM) -- GitLab From 8952ba9748a2c5f325714208ef6611c456c64655 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 8 Jul 2025 11:49:59 +0100 Subject: [PATCH 8/9] restore correct whitespace in cmakelists, add extra test with k=31 Signed-off-by: Evie Wright --- CMakeLists.txt | 2 +- test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ad3dfa1d..a74e779a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -335,7 +335,6 @@ if(NOT MSVC) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME}) target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME2}) - set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) @@ -345,6 +344,7 @@ if(NOT MSVC) 
set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16_BF16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH}) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH}) + # Use -fno-tree-vectorize option to disable compiler based vectorization set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") set_source_files_properties(${KLEIDIAI_FILES_SME2} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH}") diff --git a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp index dd542f43..7448516c 100644 --- a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp @@ -338,7 +338,8 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{8, 32, 64}, // MatMulShape{15, 32, 32}, // MatMulShape{77, 99, 64}, // - MatMulShape{77, 99, 66}), + MatMulShape{77, 99, 66}, // + MatMulShape{77, 99, 31}), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. -- GitLab From de6353c9bfd4b9cbceb3cfe4677062ed65529d2f Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 8 Jul 2025 12:05:04 +0100 Subject: [PATCH 9/9] lower requirements for tests in line with changed guards Signed-off-by: Evie Wright --- test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp index 7448516c..a7fdb6f4 100644 --- a/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp @@ -47,9 +47,9 @@ namespace kai::test { static const std::array, 2> variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp = {{ {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod), - "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", cpu_has_dotprod_and_bf16}, + "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", cpu_has_dotprod}, {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm), - "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", cpu_has_i8mm_and_bf16}, + "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", cpu_has_i8mm}, }}; class MatMulTest_bf16_qai8dxp_qsi4cxp : public ::testing::TestWithParam {}; -- GitLab
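The LHS packing added in this series boils down to a per-row dynamic quantisation: widen each bfloat16 to float32 (bf16 is the upper half of an IEEE-754 f32, which is what the vzip-with-zero loads exploit), find the row's min/max, derive an asymmetric int8 scale and nudged zero point, quantise the row, and store the negated zero point plus the reciprocal scale next to the packed data. The scalar sketch below only mirrors that flow for reference; the function names are illustrative (not part of the library), and the mr-row interleaving of the packed layout is omitted.

#include <float.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// bfloat16 is the upper 16 bits of a float32, so widening is a 16-bit shift.
// The NEON kernel does the same thing by zipping eight bf16 lanes with zeros.
float bf16_to_f32(uint16_t v) {
    const uint32_t bits = (uint32_t)v << 16;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

// Quantise one row of k bfloat16 values to int8 with a per-row scale and
// zero point, using the same min/max nudging as the packing micro-kernel.
void quantise_row_qai8dx(
    const uint16_t* row, size_t k, int8_t* out, float* recip_scale, int32_t* neg_zero_point) {
    float min0 = FLT_MAX;
    float max0 = -FLT_MAX;
    for (size_t i = 0; i < k; ++i) {
        const float x = bf16_to_f32(row[i]);
        min0 = fminf(min0, x);
        max0 = fmaxf(max0, x);
    }

    const float qmin = (float)INT8_MIN;
    const float qmax = (float)INT8_MAX;

    // The representable range always includes zero.
    const float rmin = fminf(0.0F, min0);
    const float rmax = fmaxf(0.0F, max0);
    const float scale = (rmin == rmax) ? 1.0F : (qmax - qmin) / (rmax - rmin);

    // Pick the zero point that balances the rounding error at the two range
    // ends, then clamp it to the int8 range and round to nearest.
    const float descaled_min = rmin * scale;
    const float descaled_max = rmax * scale;
    float zero_point =
        (qmin + descaled_min + qmax + descaled_max > 0) ? qmin - descaled_min : qmax - descaled_max;
    zero_point = fminf(fmaxf(zero_point, qmin), qmax);
    const int32_t nudged_zero_point = (int32_t)rintf(zero_point);

    // Quantise: scale, shift by the zero point, saturate to int8.
    for (size_t i = 0; i < k; ++i) {
        int32_t q = (int32_t)roundf(bf16_to_f32(row[i]) * scale) + nudged_zero_point;
        q = q < INT8_MIN ? INT8_MIN : (q > INT8_MAX ? INT8_MAX : q);
        out[i] = (int8_t)q;
    }

    // The packed buffer stores the negated zero point and the reciprocal
    // scale so the matmul micro-kernel can dequantise its accumulators.
    *recip_scale = (scale != 0.0F) ? 1.0F / scale : 0.0F;
    *neg_zero_point = -nudged_zero_point;
}

In the packed buffer itself the kernel interleaves mr rows block by block: each row writes k_block_len bytes starting at dst_x * k_block_len and then skips (mr - 1) * k_block_len bytes, and the per-row negated zero points followed by the reciprocal scales are stored after the mr * k_internal data bytes.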