diff --git a/CHANGELOG.md b/CHANGELOG.md index 123686b079415bffcad3106cba678872983a1c12..d0b00dc65f98f762d9754c45c6e7af45913d3385 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification for releases. +## v0.4.0 -- Upcoming Release + +- Micro-kernels to compute the matrix multiplication of dynamically quantized 8-bit integer (QAI8DX) LHS matrix, which typically holds the neural network activations, and quantized 4-bit integer (QSI4CX) RHS matrix, which typically holds the neural network weights, and the accumulation of the result into a single-precision (F32) output, optimized using the Arm® CPU feature FEAT_DotProd. + ## v0.3.0 - Advanced SIMD FP32 GEMM micro-kernel. diff --git a/CMakeLists.txt b/CMakeLists.txt index 9dc92e3a030510ec24d9b75ab565b10854a46a3d..ee35f8e6d04bd4623c64e4bda9808924d8fa987f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,6 +96,8 @@ set(KLEIDIAI_FILES_NEON set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c diff --git a/benchmark/matmul/matmul_f32.cpp b/benchmark/matmul/matmul_f32.cpp index 
e956a7c58c4b7e277c5f59be1b7fd26c30e68dd8..bc6b712eff7cbed458857946909366bc6b3fa1bf 100644 --- a/benchmark/matmul/matmul_f32.cpp +++ b/benchmark/matmul/matmul_f32.cpp @@ -16,8 +16,10 @@ #include "benchmark/matmul/matmul_utils.hpp" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" @@ -154,6 +156,30 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + 
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod"}, }; void RegisterBenchmarks(size_t m, size_t n, size_t k) { diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt index 8993c16f721984c80d8db78410361626b603cf01..02dd1ede8644b195de599fe417bfaf40c770a56d 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt @@ -28,7 +28,10 @@ 
add_executable(matmul_clamp_f32_qai8dxp_qsi4cxp ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c - ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c) + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c + ) # Compile with DotProd and I8MM features enabled target_compile_options(matmul_clamp_f32_qai8dxp_qsi4cxp PRIVATE -march=armv8.2-a+dotprod+i8mm) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp index fe75aa5885589c793ed834fadfa40e29d301726c..057f8eb58c143020c75de6cc15ad00f070860f11 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp @@ -18,8 +18,10 @@ #include "kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" @@ -113,6 +115,30 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { 
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, + 
"matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod"}, }; diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 67aa0d9ec8a9748f6c9d6a939541a72d7fbbc3b0..6a08810752ce584f49cb6f460cea41b421451544 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -71,6 +71,26 @@ kai_c_library( ], ) +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h"], + cpu_uarch = kai_cpu_dotprod(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h"], + cpu_uarch = kai_cpu_dotprod(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + kai_c_library( name = "clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c"], @@ -281,8 +301,10 @@ kai_c_library( ":clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", + ":clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod", ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm", + ":clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod", ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm", ":clamp_f32_qai8dxp_qsi4c32p_interface", diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..130b86aa0af9234ef4e4a3972bbf9b1c4ce266a0 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.c @@ -0,0 +1,754 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__ARM_FEATURE_DOTPROD) +#error "Dotprod extension required to compile this micro-kernel" +#else // Architectural features check. +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 16; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + 
+inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* 
restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + __asm__ __volatile__( + "mov x13, %x[m]\n" + "mov x12, #0x80\n" + "mov x20, #0x20\n" + "cmp x13, #0x10\n" + "madd x12, %x[num_blocks], x12, x20\n" + "blt 14f\n" + "1:" // Row loop + "mov x11, %x[rhs_packed]\n" + "mov x10, %x[n]\n" + "add x9, %x[dst], %x[dst_stride_row], LSL #4\n" + "2:" // Column loop + "mov x27, %x[lhs_packed]\n" + "movi v31.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "mov x23, %x[num_blocks]\n" + "movi v29.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "add x22, x27, x12\n" + "add x21, x22, x12\n" + "add x20, x21, x12\n" + "movi v25.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "3:" // Sub block loop + "ldr q13, [x11, #0x0]\n" + "ldr q14, [x27, #0x0]\n" + "movi v10.16b, #0xf0\n" + "subs x23, x23, #0x1\n" + "ldr q6, [x22, #0x0]\n" + "ldr q15, [x21, #0x0]\n" + "ldr q3, [x20, #0x0]\n" + "ldr q12, [x11, #0x10]\n" + "ldr q8, [x27, #0x10]\n" + "ldr q4, [x22, #0x10]\n" + "shl v9.16b, v13.16b, #0x4\n" + "and v13.16b, v13.16b, v10.16b\n" + "ldr q0, [x21, #0x10]\n" + "ldr q1, [x20, #0x10]\n" + "ldr q5, [x11, #0x20]\n" + "ldr q2, [x27, #0x20]\n" + "shl v7.16b, v12.16b, #0x4\n" + "and v12.16b, v12.16b, v10.16b\n" + "ldr q11, [x22, #0x20]\n" + ".inst 0x4f8ee13f // sdot v31.4s, v9.16b, v14.4b[0]\n" + ".inst 0x4faee13e // sdot v30.4s, v9.16b, v14.4b[1]\n" + ".inst 0x4f8ee93d // sdot v29.4s, v9.16b, v14.4b[2]\n" + ".inst 0x4faee93c // sdot v28.4s, v9.16b, v14.4b[3]\n" + "ldr q14, [x21, #0x20]\n" + ".inst 0x4f86e13b // sdot v27.4s, 
v9.16b, v6.4b[0]\n" + ".inst 0x4fa6e13a // sdot v26.4s, v9.16b, v6.4b[1]\n" + ".inst 0x4f86e939 // sdot v25.4s, v9.16b, v6.4b[2]\n" + ".inst 0x4fa6e938 // sdot v24.4s, v9.16b, v6.4b[3]\n" + "ldr q6, [x20, #0x20]\n" + ".inst 0x4f8fe137 // sdot v23.4s, v9.16b, v15.4b[0]\n" + ".inst 0x4fafe136 // sdot v22.4s, v9.16b, v15.4b[1]\n" + ".inst 0x4f8fe935 // sdot v21.4s, v9.16b, v15.4b[2]\n" + ".inst 0x4fafe934 // sdot v20.4s, v9.16b, v15.4b[3]\n" + "ldr q15, [x11, #0x30]\n" + "add x11, x11, #0x40\n" + ".inst 0x4f83e133 // sdot v19.4s, v9.16b, v3.4b[0]\n" + ".inst 0x4fa3e132 // sdot v18.4s, v9.16b, v3.4b[1]\n" + ".inst 0x4f83e931 // sdot v17.4s, v9.16b, v3.4b[2]\n" + ".inst 0x4fa3e930 // sdot v16.4s, v9.16b, v3.4b[3]\n" + "ldr q9, [x27, #0x30]\n" + "ldr q3, [x22, #0x30]\n" + ".inst 0x4f88e0ff // sdot v31.4s, v7.16b, v8.4b[0]\n" + ".inst 0x4fa8e0fe // sdot v30.4s, v7.16b, v8.4b[1]\n" + ".inst 0x4f88e8fd // sdot v29.4s, v7.16b, v8.4b[2]\n" + ".inst 0x4fa8e8fc // sdot v28.4s, v7.16b, v8.4b[3]\n" + "ldr q8, [x21, #0x30]\n" + ".inst 0x4f84e0fb // sdot v27.4s, v7.16b, v4.4b[0]\n" + ".inst 0x4fa4e0fa // sdot v26.4s, v7.16b, v4.4b[1]\n" + ".inst 0x4f84e8f9 // sdot v25.4s, v7.16b, v4.4b[2]\n" + ".inst 0x4fa4e8f8 // sdot v24.4s, v7.16b, v4.4b[3]\n" + "ldr q4, [x20, #0x30]\n" + ".inst 0x4f80e0f7 // sdot v23.4s, v7.16b, v0.4b[0]\n" + ".inst 0x4fa0e0f6 // sdot v22.4s, v7.16b, v0.4b[1]\n" + ".inst 0x4f80e8f5 // sdot v21.4s, v7.16b, v0.4b[2]\n" + ".inst 0x4fa0e8f4 // sdot v20.4s, v7.16b, v0.4b[3]\n" + "ldr q0, [x27, #0x40]\n" + ".inst 0x4f81e0f3 // sdot v19.4s, v7.16b, v1.4b[0]\n" + ".inst 0x4fa1e0f2 // sdot v18.4s, v7.16b, v1.4b[1]\n" + ".inst 0x4f81e8f1 // sdot v17.4s, v7.16b, v1.4b[2]\n" + ".inst 0x4fa1e8f0 // sdot v16.4s, v7.16b, v1.4b[3]\n" + "ldr q1, [x22, #0x40]\n" + "shl v7.16b, v5.16b, #0x4\n" + "and v5.16b, v5.16b, v10.16b\n" + ".inst 0x4f82e0ff // sdot v31.4s, v7.16b, v2.4b[0]\n" + ".inst 0x4fa2e0fe // sdot v30.4s, v7.16b, v2.4b[1]\n" + ".inst 0x4f82e8fd // sdot v29.4s, v7.16b, 
v2.4b[2]\n" + ".inst 0x4fa2e8fc // sdot v28.4s, v7.16b, v2.4b[3]\n" + "ldr q2, [x21, #0x40]\n" + ".inst 0x4f8be0fb // sdot v27.4s, v7.16b, v11.4b[0]\n" + ".inst 0x4fabe0fa // sdot v26.4s, v7.16b, v11.4b[1]\n" + ".inst 0x4f8be8f9 // sdot v25.4s, v7.16b, v11.4b[2]\n" + ".inst 0x4fabe8f8 // sdot v24.4s, v7.16b, v11.4b[3]\n" + "ldr q11, [x20, #0x40]\n" + ".inst 0x4f8ee0f7 // sdot v23.4s, v7.16b, v14.4b[0]\n" + ".inst 0x4faee0f6 // sdot v22.4s, v7.16b, v14.4b[1]\n" + ".inst 0x4f8ee8f5 // sdot v21.4s, v7.16b, v14.4b[2]\n" + ".inst 0x4faee8f4 // sdot v20.4s, v7.16b, v14.4b[3]\n" + "ldr q14, [x27, #0x50]\n" + ".inst 0x4f86e0f3 // sdot v19.4s, v7.16b, v6.4b[0]\n" + ".inst 0x4fa6e0f2 // sdot v18.4s, v7.16b, v6.4b[1]\n" + ".inst 0x4f86e8f1 // sdot v17.4s, v7.16b, v6.4b[2]\n" + ".inst 0x4fa6e8f0 // sdot v16.4s, v7.16b, v6.4b[3]\n" + "ldr q6, [x22, #0x50]\n" + "shl v7.16b, v15.16b, #0x4\n" + "and v15.16b, v15.16b, v10.16b\n" + "ldr q10, [x21, #0x50]\n" + ".inst 0x4f89e0ff // sdot v31.4s, v7.16b, v9.4b[0]\n" + ".inst 0x4fa9e0fe // sdot v30.4s, v7.16b, v9.4b[1]\n" + ".inst 0x4f89e8fd // sdot v29.4s, v7.16b, v9.4b[2]\n" + ".inst 0x4fa9e8fc // sdot v28.4s, v7.16b, v9.4b[3]\n" + "ldr q9, [x20, #0x50]\n" + ".inst 0x4f83e0fb // sdot v27.4s, v7.16b, v3.4b[0]\n" + ".inst 0x4fa3e0fa // sdot v26.4s, v7.16b, v3.4b[1]\n" + ".inst 0x4f83e8f9 // sdot v25.4s, v7.16b, v3.4b[2]\n" + ".inst 0x4fa3e8f8 // sdot v24.4s, v7.16b, v3.4b[3]\n" + "ldr q3, [x27, #0x60]\n" + ".inst 0x4f88e0f7 // sdot v23.4s, v7.16b, v8.4b[0]\n" + ".inst 0x4fa8e0f6 // sdot v22.4s, v7.16b, v8.4b[1]\n" + ".inst 0x4f88e8f5 // sdot v21.4s, v7.16b, v8.4b[2]\n" + ".inst 0x4fa8e8f4 // sdot v20.4s, v7.16b, v8.4b[3]\n" + "ldr q8, [x22, #0x60]\n" + ".inst 0x4f84e0f3 // sdot v19.4s, v7.16b, v4.4b[0]\n" + ".inst 0x4fa4e0f2 // sdot v18.4s, v7.16b, v4.4b[1]\n" + ".inst 0x4f84e8f1 // sdot v17.4s, v7.16b, v4.4b[2]\n" + ".inst 0x4fa4e8f0 // sdot v16.4s, v7.16b, v4.4b[3]\n" + "ldr q7, [x21, #0x60]\n" + "ldr q4, [x20, #0x60]\n" + ".inst 
0x4f80e1bf // sdot v31.4s, v13.16b, v0.4b[0]\n" + ".inst 0x4fa0e1be // sdot v30.4s, v13.16b, v0.4b[1]\n" + ".inst 0x4f80e9bd // sdot v29.4s, v13.16b, v0.4b[2]\n" + ".inst 0x4fa0e9bc // sdot v28.4s, v13.16b, v0.4b[3]\n" + "ldr q0, [x27, #0x70]\n" + "add x27, x27, #0x80\n" + ".inst 0x4f81e1bb // sdot v27.4s, v13.16b, v1.4b[0]\n" + ".inst 0x4fa1e1ba // sdot v26.4s, v13.16b, v1.4b[1]\n" + ".inst 0x4f81e9b9 // sdot v25.4s, v13.16b, v1.4b[2]\n" + ".inst 0x4fa1e9b8 // sdot v24.4s, v13.16b, v1.4b[3]\n" + "ldr q1, [x22, #0x70]\n" + "add x22, x22, #0x80\n" + ".inst 0x4f82e1b7 // sdot v23.4s, v13.16b, v2.4b[0]\n" + ".inst 0x4fa2e1b6 // sdot v22.4s, v13.16b, v2.4b[1]\n" + ".inst 0x4f82e9b5 // sdot v21.4s, v13.16b, v2.4b[2]\n" + ".inst 0x4fa2e9b4 // sdot v20.4s, v13.16b, v2.4b[3]\n" + "ldr q2, [x21, #0x70]\n" + "add x21, x21, #0x80\n" + ".inst 0x4f8be1b3 // sdot v19.4s, v13.16b, v11.4b[0]\n" + ".inst 0x4fabe1b2 // sdot v18.4s, v13.16b, v11.4b[1]\n" + ".inst 0x4f8be9b1 // sdot v17.4s, v13.16b, v11.4b[2]\n" + ".inst 0x4fabe9b0 // sdot v16.4s, v13.16b, v11.4b[3]\n" + "ldr q11, [x20, #0x70]\n" + "add x20, x20, #0x80\n" + ".inst 0x4f8ee19f // sdot v31.4s, v12.16b, v14.4b[0]\n" + ".inst 0x4faee19e // sdot v30.4s, v12.16b, v14.4b[1]\n" + ".inst 0x4f8ee99d // sdot v29.4s, v12.16b, v14.4b[2]\n" + ".inst 0x4faee99c // sdot v28.4s, v12.16b, v14.4b[3]\n" + ".inst 0x4f86e19b // sdot v27.4s, v12.16b, v6.4b[0]\n" + ".inst 0x4fa6e19a // sdot v26.4s, v12.16b, v6.4b[1]\n" + ".inst 0x4f86e999 // sdot v25.4s, v12.16b, v6.4b[2]\n" + ".inst 0x4fa6e998 // sdot v24.4s, v12.16b, v6.4b[3]\n" + ".inst 0x4f8ae197 // sdot v23.4s, v12.16b, v10.4b[0]\n" + ".inst 0x4faae196 // sdot v22.4s, v12.16b, v10.4b[1]\n" + ".inst 0x4f8ae995 // sdot v21.4s, v12.16b, v10.4b[2]\n" + ".inst 0x4faae994 // sdot v20.4s, v12.16b, v10.4b[3]\n" + ".inst 0x4f89e193 // sdot v19.4s, v12.16b, v9.4b[0]\n" + ".inst 0x4fa9e192 // sdot v18.4s, v12.16b, v9.4b[1]\n" + ".inst 0x4f89e991 // sdot v17.4s, v12.16b, v9.4b[2]\n" + ".inst 
0x4fa9e990 // sdot v16.4s, v12.16b, v9.4b[3]\n" + ".inst 0x4f83e0bf // sdot v31.4s, v5.16b, v3.4b[0]\n" + ".inst 0x4fa3e0be // sdot v30.4s, v5.16b, v3.4b[1]\n" + ".inst 0x4f83e8bd // sdot v29.4s, v5.16b, v3.4b[2]\n" + ".inst 0x4fa3e8bc // sdot v28.4s, v5.16b, v3.4b[3]\n" + ".inst 0x4f88e0bb // sdot v27.4s, v5.16b, v8.4b[0]\n" + ".inst 0x4fa8e0ba // sdot v26.4s, v5.16b, v8.4b[1]\n" + ".inst 0x4f88e8b9 // sdot v25.4s, v5.16b, v8.4b[2]\n" + ".inst 0x4fa8e8b8 // sdot v24.4s, v5.16b, v8.4b[3]\n" + ".inst 0x4f87e0b7 // sdot v23.4s, v5.16b, v7.4b[0]\n" + ".inst 0x4fa7e0b6 // sdot v22.4s, v5.16b, v7.4b[1]\n" + ".inst 0x4f87e8b5 // sdot v21.4s, v5.16b, v7.4b[2]\n" + ".inst 0x4fa7e8b4 // sdot v20.4s, v5.16b, v7.4b[3]\n" + ".inst 0x4f84e0b3 // sdot v19.4s, v5.16b, v4.4b[0]\n" + ".inst 0x4fa4e0b2 // sdot v18.4s, v5.16b, v4.4b[1]\n" + ".inst 0x4f84e8b1 // sdot v17.4s, v5.16b, v4.4b[2]\n" + ".inst 0x4fa4e8b0 // sdot v16.4s, v5.16b, v4.4b[3]\n" + ".inst 0x4f80e1ff // sdot v31.4s, v15.16b, v0.4b[0]\n" + ".inst 0x4fa0e1fe // sdot v30.4s, v15.16b, v0.4b[1]\n" + ".inst 0x4f80e9fd // sdot v29.4s, v15.16b, v0.4b[2]\n" + ".inst 0x4fa0e9fc // sdot v28.4s, v15.16b, v0.4b[3]\n" + ".inst 0x4f81e1fb // sdot v27.4s, v15.16b, v1.4b[0]\n" + ".inst 0x4fa1e1fa // sdot v26.4s, v15.16b, v1.4b[1]\n" + ".inst 0x4f81e9f9 // sdot v25.4s, v15.16b, v1.4b[2]\n" + ".inst 0x4fa1e9f8 // sdot v24.4s, v15.16b, v1.4b[3]\n" + ".inst 0x4f82e1f7 // sdot v23.4s, v15.16b, v2.4b[0]\n" + ".inst 0x4fa2e1f6 // sdot v22.4s, v15.16b, v2.4b[1]\n" + ".inst 0x4f82e9f5 // sdot v21.4s, v15.16b, v2.4b[2]\n" + ".inst 0x4fa2e9f4 // sdot v20.4s, v15.16b, v2.4b[3]\n" + ".inst 0x4f8be1f3 // sdot v19.4s, v15.16b, v11.4b[0]\n" + ".inst 0x4fabe1f2 // sdot v18.4s, v15.16b, v11.4b[1]\n" + ".inst 0x4f8be9f1 // sdot v17.4s, v15.16b, v11.4b[2]\n" + ".inst 0x4fabe9f0 // sdot v16.4s, v15.16b, v11.4b[3]\n" + "bgt 3b\n" + "ldr q5, [x11, #0x0]\n" + "ld1 { v1.4s }, [x27]\n" + "add x27, x27, #0x10\n" + "ldr q4, [x11, #0x10]\n" + "ldr q0, [x27, 
#0x0]\n" + "add x11, x11, #0x20\n" + "mla v31.4s, v5.4s, v1.s[0]\n" + "mla v30.4s, v5.4s, v1.s[1]\n" + "mla v29.4s, v5.4s, v1.s[2]\n" + "mla v28.4s, v5.4s, v1.s[3]\n" + "fmul v3.4s, v4.4s, v0.s[0]\n" + "fmul v2.4s, v4.4s, v0.s[1]\n" + "fmul v1.4s, v4.4s, v0.s[2]\n" + "scvtf v31.4s, v31.4s\n" + "fmul v0.4s, v4.4s, v0.s[3]\n" + "scvtf v30.4s, v30.4s\n" + "scvtf v29.4s, v29.4s\n" + "scvtf v28.4s, v28.4s\n" + "fmul v31.4s, v31.4s, v3.4s\n" + "fmul v30.4s, v30.4s, v2.4s\n" + "fmul v29.4s, v29.4s, v1.4s\n" + "fmul v28.4s, v28.4s, v0.4s\n" + "ld1 { v1.4s }, [x22]\n" + "add x22, x22, #0x10\n" + "ldr q0, [x22, #0x0]\n" + "mla v27.4s, v5.4s, v1.s[0]\n" + "mla v26.4s, v5.4s, v1.s[1]\n" + "mla v25.4s, v5.4s, v1.s[2]\n" + "mla v24.4s, v5.4s, v1.s[3]\n" + "fmul v3.4s, v4.4s, v0.s[0]\n" + "fmul v2.4s, v4.4s, v0.s[1]\n" + "fmul v1.4s, v4.4s, v0.s[2]\n" + "scvtf v27.4s, v27.4s\n" + "fmul v0.4s, v4.4s, v0.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v24.4s, v24.4s\n" + "fmul v27.4s, v27.4s, v3.4s\n" + "fmul v26.4s, v26.4s, v2.4s\n" + "fmul v25.4s, v25.4s, v1.4s\n" + "fmul v24.4s, v24.4s, v0.4s\n" + "ld1 { v1.4s }, [x21]\n" + "add x21, x21, #0x10\n" + "ldr q0, [x21, #0x0]\n" + "mla v23.4s, v5.4s, v1.s[0]\n" + "mla v22.4s, v5.4s, v1.s[1]\n" + "mla v21.4s, v5.4s, v1.s[2]\n" + "mla v20.4s, v5.4s, v1.s[3]\n" + "fmul v3.4s, v4.4s, v0.s[0]\n" + "fmul v2.4s, v4.4s, v0.s[1]\n" + "fmul v1.4s, v4.4s, v0.s[2]\n" + "scvtf v23.4s, v23.4s\n" + "fmul v0.4s, v4.4s, v0.s[3]\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v23.4s, v23.4s, v3.4s\n" + "fmul v22.4s, v22.4s, v2.4s\n" + "fmul v21.4s, v21.4s, v1.4s\n" + "fmul v20.4s, v20.4s, v0.4s\n" + "ld1 { v1.4s }, [x20]\n" + "add x20, x20, #0x10\n" + "ldr q0, [x20, #0x0]\n" + "mla v19.4s, v5.4s, v1.s[0]\n" + "mla v18.4s, v5.4s, v1.s[1]\n" + "mla v17.4s, v5.4s, v1.s[2]\n" + "mla v16.4s, v5.4s, v1.s[3]\n" + "fmul v3.4s, v4.4s, v0.s[0]\n" + "fmul v2.4s, v4.4s, v0.s[1]\n" + "fmul v1.4s, 
v4.4s, v0.s[2]\n" + "scvtf v19.4s, v19.4s\n" + "fmul v0.4s, v4.4s, v0.s[3]\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v17.4s, v17.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmul v19.4s, v19.4s, v3.4s\n" + "fmul v18.4s, v18.4s, v2.4s\n" + "fmul v17.4s, v17.4s, v1.4s\n" + "fmul v16.4s, v16.4s, v0.4s\n" + "ldr q2, [x11, #0x0]\n" + "ld1r { v1.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x10, #0x4\n" + "ld1r { v0.4s }, [x20]\n" + "add x11, x11, #0x10\n" + "fadd v31.4s, v31.4s, v2.4s\n" + "fadd v30.4s, v30.4s, v2.4s\n" + "fadd v29.4s, v29.4s, v2.4s\n" + "fadd v28.4s, v28.4s, v2.4s\n" + "fadd v27.4s, v27.4s, v2.4s\n" + "fadd v26.4s, v26.4s, v2.4s\n" + "fadd v25.4s, v25.4s, v2.4s\n" + "fadd v24.4s, v24.4s, v2.4s\n" + "fadd v23.4s, v23.4s, v2.4s\n" + "fadd v22.4s, v22.4s, v2.4s\n" + "fadd v21.4s, v21.4s, v2.4s\n" + "fadd v20.4s, v20.4s, v2.4s\n" + "fadd v19.4s, v19.4s, v2.4s\n" + "fadd v18.4s, v18.4s, v2.4s\n" + "fadd v17.4s, v17.4s, v2.4s\n" + "fadd v16.4s, v16.4s, v2.4s\n" + "fmax v31.4s, v31.4s, v1.4s\n" + "fmax v30.4s, v30.4s, v1.4s\n" + "fmax v29.4s, v29.4s, v1.4s\n" + "fmax v28.4s, v28.4s, v1.4s\n" + "fmax v27.4s, v27.4s, v1.4s\n" + "fmax v26.4s, v26.4s, v1.4s\n" + "fmax v25.4s, v25.4s, v1.4s\n" + "fmax v24.4s, v24.4s, v1.4s\n" + "fmax v23.4s, v23.4s, v1.4s\n" + "fmax v22.4s, v22.4s, v1.4s\n" + "fmax v21.4s, v21.4s, v1.4s\n" + "fmax v20.4s, v20.4s, v1.4s\n" + "fmax v19.4s, v19.4s, v1.4s\n" + "fmax v18.4s, v18.4s, v1.4s\n" + "fmax v17.4s, v17.4s, v1.4s\n" + "fmax v16.4s, v16.4s, v1.4s\n" + "fmin v31.4s, v31.4s, v0.4s\n" + "fmin v30.4s, v30.4s, v0.4s\n" + "fmin v29.4s, v29.4s, v0.4s\n" + "fmin v28.4s, v28.4s, v0.4s\n" + "fmin v27.4s, v27.4s, v0.4s\n" + "fmin v26.4s, v26.4s, v0.4s\n" + "fmin v25.4s, v25.4s, v0.4s\n" + "fmin v24.4s, v24.4s, v0.4s\n" + "fmin v23.4s, v23.4s, v0.4s\n" + "fmin v22.4s, v22.4s, v0.4s\n" + "fmin v21.4s, v21.4s, v0.4s\n" + "fmin v20.4s, v20.4s, v0.4s\n" + "fmin v19.4s, v19.4s, v0.4s\n" + "fmin v18.4s, v18.4s, v0.4s\n" + "fmin v17.4s, 
v17.4s, v0.4s\n" + "fmin v16.4s, v16.4s, v0.4s\n" + "blt 8f\n" + "mov x20, %x[dst]\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q27, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q20, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q17, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q16, [x20, #0x0]\n" + "b 13f\n" + "8:" // Partial output + "mov x28, %x[dst]\n" + "add x26, x28, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x28, %x[dst_stride_row], LSL #1\n" + "add x21, x28, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "add x27, x23, %x[dst_stride_row]\n" + "tbz x10, #1, 9f\n" + "st1 { v24.d }[0], [x23], #0x8\n" + "st1 { v25.d }[0], [x25], #0x8\n" + "st1 { v26.d }[0], [x24], #0x8\n" + "st1 { v27.d }[0], [x26], #0x8\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x22], #0x8\n" + "st1 { v30.d }[0], [x21], #0x8\n" + "st1 { v31.d }[0], [x28], #0x8\n" + "tbz x10, #0, 10f\n" + "st1 { v24.s }[2], [x23]\n" + "st1 { v25.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v27.s }[2], [x26]\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v29.s 
}[2], [x22]\n" + "st1 { v30.s }[2], [x21]\n" + "st1 { v31.s }[2], [x28]\n" + "b 10f\n" + "9:" // Output block 0: partial_1_0 + "st1 { v24.s }[0], [x23]\n" + "st1 { v25.s }[0], [x25]\n" + "st1 { v26.s }[0], [x24]\n" + "st1 { v27.s }[0], [x26]\n" + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x22]\n" + "st1 { v30.s }[0], [x21]\n" + "st1 { v31.s }[0], [x28]\n" + "10:" // Output block 0: Done + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x27, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row], LSL #1\n" + "add x23, x27, %x[dst_stride_row]\n" + "add x22, x25, %x[dst_stride_row]\n" + "add x21, x26, %x[dst_stride_row]\n" + "add x20, x24, %x[dst_stride_row]\n" + "tbz x10, #1, 11f\n" + "st1 { v16.d }[0], [x20], #0x8\n" + "st1 { v17.d }[0], [x24], #0x8\n" + "st1 { v18.d }[0], [x21], #0x8\n" + "st1 { v19.d }[0], [x26], #0x8\n" + "st1 { v20.d }[0], [x22], #0x8\n" + "st1 { v21.d }[0], [x25], #0x8\n" + "st1 { v22.d }[0], [x23], #0x8\n" + "st1 { v23.d }[0], [x27], #0x8\n" + "tbz x10, #0, 12f\n" + "st1 { v16.s }[2], [x20]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v18.s }[2], [x21]\n" + "st1 { v19.s }[2], [x26]\n" + "st1 { v20.s }[2], [x22]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v22.s }[2], [x23]\n" + "st1 { v23.s }[2], [x27]\n" + "b 12f\n" + "11:" // Output block 1: partial_1_0 + "st1 { v16.s }[0], [x20]\n" + "st1 { v17.s }[0], [x24]\n" + "st1 { v18.s }[0], [x21]\n" + "st1 { v19.s }[0], [x26]\n" + "st1 { v20.s }[0], [x22]\n" + "st1 { v21.s }[0], [x25]\n" + "st1 { v22.s }[0], [x23]\n" + "st1 { v23.s }[0], [x27]\n" + "12:" // Output block 1: Done + "13:" // Output stage exit + "subs x10, x10, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[dst], x9\n" + "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" + "bge 1b\n" + "14:" // Row loop skip + "cbz x13, 23f\n" + "15:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], 
%x[dst_stride_row], LSL #2\n" + "16:" // Row tail: Column loop + "mov x27, %x[lhs_packed]\n" + "movi v31.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "mov x20, %x[num_blocks]\n" + "movi v29.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "17:" // Row tail: Sub block loop + "ldr q4, [x26, #0x0]\n" + "ldr q3, [x27, #0x0]\n" + "movi v2.16b, #0xf0\n" + "subs x20, x20, #0x1\n" + "ldr q1, [x26, #0x10]\n" + "ldr q0, [x27, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x27, #0x20]\n" + "ldr q25, [x26, #0x30]\n" + "ldr q24, [x27, #0x30]\n" + "shl v23.16b, v4.16b, #0x4\n" + "and v4.16b, v4.16b, v2.16b\n" + "ldr q22, [x27, #0x40]\n" + "ldr q21, [x27, #0x50]\n" + "shl v20.16b, v1.16b, #0x4\n" + "and v1.16b, v1.16b, v2.16b\n" + "ldr q19, [x27, #0x60]\n" + "ldr q18, [x27, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "and v27.16b, v27.16b, v2.16b\n" + ".inst 0x4f83e2ff // sdot v31.4s, v23.16b, v3.4b[0]\n" + ".inst 0x4fa3e2fe // sdot v30.4s, v23.16b, v3.4b[1]\n" + "shl v16.16b, v25.16b, #0x4\n" + "add x26, x26, #0x40\n" + ".inst 0x4f83eafd // sdot v29.4s, v23.16b, v3.4b[2]\n" + ".inst 0x4fa3eafc // sdot v28.4s, v23.16b, v3.4b[3]\n" + "and v25.16b, v25.16b, v2.16b\n" + "add x27, x27, #0x80\n" + ".inst 0x4f80e29f // sdot v31.4s, v20.16b, v0.4b[0]\n" + ".inst 0x4fa0e29e // sdot v30.4s, v20.16b, v0.4b[1]\n" + ".inst 0x4f80ea9d // sdot v29.4s, v20.16b, v0.4b[2]\n" + ".inst 0x4fa0ea9c // sdot v28.4s, v20.16b, v0.4b[3]\n" + ".inst 0x4f9ae23f // sdot v31.4s, v17.16b, v26.4b[0]\n" + ".inst 0x4fbae23e // sdot v30.4s, v17.16b, v26.4b[1]\n" + ".inst 0x4f9aea3d // sdot v29.4s, v17.16b, v26.4b[2]\n" + ".inst 0x4fbaea3c // sdot v28.4s, v17.16b, v26.4b[3]\n" + ".inst 0x4f98e21f // sdot v31.4s, v16.16b, v24.4b[0]\n" + ".inst 0x4fb8e21e // sdot v30.4s, v16.16b, v24.4b[1]\n" + ".inst 0x4f98ea1d // sdot v29.4s, v16.16b, v24.4b[2]\n" + ".inst 0x4fb8ea1c // sdot v28.4s, v16.16b, v24.4b[3]\n" + ".inst 0x4f96e09f // sdot v31.4s, v4.16b, v22.4b[0]\n" + ".inst 0x4fb6e09e // sdot v30.4s, v4.16b, v22.4b[1]\n" + ".inst 
0x4f96e89d // sdot v29.4s, v4.16b, v22.4b[2]\n" + ".inst 0x4fb6e89c // sdot v28.4s, v4.16b, v22.4b[3]\n" + ".inst 0x4f95e03f // sdot v31.4s, v1.16b, v21.4b[0]\n" + ".inst 0x4fb5e03e // sdot v30.4s, v1.16b, v21.4b[1]\n" + ".inst 0x4f95e83d // sdot v29.4s, v1.16b, v21.4b[2]\n" + ".inst 0x4fb5e83c // sdot v28.4s, v1.16b, v21.4b[3]\n" + ".inst 0x4f93e37f // sdot v31.4s, v27.16b, v19.4b[0]\n" + ".inst 0x4fb3e37e // sdot v30.4s, v27.16b, v19.4b[1]\n" + ".inst 0x4f93eb7d // sdot v29.4s, v27.16b, v19.4b[2]\n" + ".inst 0x4fb3eb7c // sdot v28.4s, v27.16b, v19.4b[3]\n" + ".inst 0x4f92e33f // sdot v31.4s, v25.16b, v18.4b[0]\n" + ".inst 0x4fb2e33e // sdot v30.4s, v25.16b, v18.4b[1]\n" + ".inst 0x4f92eb3d // sdot v29.4s, v25.16b, v18.4b[2]\n" + ".inst 0x4fb2eb3c // sdot v28.4s, v25.16b, v18.4b[3]\n" + "bgt 17b\n" + "ldr q18, [x26, #0x0]\n" + "ld1 { v17.4s }, [x27]\n" + "add x27, x27, #0x10\n" + "ldr q20, [x26, #0x10]\n" + "ldr q16, [x27, #0x0]\n" + "add x26, x26, #0x20\n" + "mla v31.4s, v18.4s, v17.s[0]\n" + "mla v30.4s, v18.4s, v17.s[1]\n" + "mla v29.4s, v18.4s, v17.s[2]\n" + "mla v28.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v20.4s, v16.s[0]\n" + "fmul v18.4s, v20.4s, v16.s[1]\n" + "fmul v17.4s, v20.4s, v16.s[2]\n" + "scvtf v31.4s, v31.4s\n" + "fmul v16.4s, v20.4s, v16.s[3]\n" + "scvtf v30.4s, v30.4s\n" + "scvtf v29.4s, v29.4s\n" + "scvtf v28.4s, v28.4s\n" + "fmul v31.4s, v31.4s, v19.4s\n" + "fmul v30.4s, v30.4s, v18.4s\n" + "fmul v29.4s, v29.4s, v17.4s\n" + "fmul v28.4s, v28.4s, v16.4s\n" + "ldr q18, [x26, #0x0]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x10\n" + "fadd v31.4s, v31.4s, v18.4s\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fadd v29.4s, v29.4s, v18.4s\n" + "fadd v28.4s, v28.4s, v18.4s\n" + "fmax v31.4s, v31.4s, v17.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmax v29.4s, v29.4s, v17.4s\n" + "fmax v28.4s, v28.4s, v17.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" + "fmin 
v30.4s, v30.4s, v16.4s\n" + "fmin v29.4s, v29.4s, v16.4s\n" + "fmin v28.4s, v28.4s, v16.4s\n" + "blt 19f\n" + "mov x20, %x[dst]\n" + "cmp x13, #0x1\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x13, #0x2\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x13, #0x3\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "str q28, [x20, #0x0]\n" + "b 22f\n" + "19:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x13, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x13, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x13, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 20f\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x21], #0x8\n" + "st1 { v30.d }[0], [x22], #0x8\n" + "st1 { v31.d }[0], [x23], #0x8\n" + "tbz x25, #0, 21f\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v30.s }[2], [x22]\n" + "st1 { v31.s }[2], [x23]\n" + "b 21f\n" + "20:" // Row tail: Output block 0: partial_1_0 + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x21]\n" + "st1 { v30.s }[0], [x22]\n" + "st1 { v31.s }[0], [x23]\n" + "21:" // Row tail: Output block 0: Done + "22:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 16b\n" + "subs x13, x13, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x12\n" + "mov %x[dst], x24\n" + "bgt 15b\n" + "23:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", 
"v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", + "x28"); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..63c538be34164af3956011ba42bc297d2aa3dfa1 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h @@ -0,0 +1,136 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix with +/// the @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the RHS matrix with +/// the @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the RHS matrix with +/// the @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dxp) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 16 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4cxp) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. 
+/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod( + size_t n_idx, // + size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 16. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod(size_t m, size_t n); + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dxp) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cxp) and packed. +/// Output tile: (rows x cols) = 16 x 4 +/// Accumulation performed in a single for loop: 32 +/// Extension used: dotprod +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. 
+/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref +/// kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 +/// OR kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod( + size_t m, size_t n, size_t k, + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..d0549134a5926f4fe26727a23e106f434c034dcd --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.c @@ -0,0 +1,843 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__ARM_FEATURE_DOTPROD) +#error "Dotprod extension required to compile this micro-kernel" +#else // Architectural features check. 
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 8; +static const size_t kai_mr = 4; +static const size_t kai_nr = 8; +static const size_t kai_kr = 8; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void) { + return kai_nr; +} + +size_t 
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v13.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 12f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v6.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "mov 
x21, %x[num_blocks]\n" + "movi v9.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "add x20, x22, x11\n" + "movi v11.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v8.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "3:" // Sub block loop + "ldr q31, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q26, [x22, #0x0]\n" + "ldr q2, [x20, #0x0]\n" + "ldr q1, [x10, #0x20]\n" + "ldr q16, [x10, #0x30]\n" + "ldr q22, [x22, #0x10]\n" + "ldr q23, [x20, #0x10]\n" + "shl v27.16b, v31.16b, #0x4\n" + "shl v19.16b, v7.16b, #0x4\n" + "ldr q29, [x10, #0x40]\n" + "ldr q25, [x10, #0x50]\n" + "and v31.16b, v31.16b, v13.16b\n" + "and v7.16b, v7.16b, v13.16b\n" + "ldr q24, [x22, #0x20]\n" + "ldr q0, [x20, #0x20]\n" + "shl v18.16b, v1.16b, #0x4\n" + "and v1.16b, v1.16b, v13.16b\n" + ".inst 0x4f9ae366 // sdot v6.4s, v27.16b, v26.4b[0]\n" + ".inst 0x4f9ae26f // sdot v15.4s, v19.16b, v26.4b[0]\n" + ".inst 0x4fbae369 // sdot v9.4s, v27.16b, v26.4b[1]\n" + ".inst 0x4fbae26c // sdot v12.4s, v19.16b, v26.4b[1]\n" + ".inst 0x4f9aeb74 // sdot v20.4s, v27.16b, v26.4b[2]\n" + ".inst 0x4f9aea7e // sdot v30.4s, v19.16b, v26.4b[2]\n" + ".inst 0x4fbaeb6b // sdot v11.4s, v27.16b, v26.4b[3]\n" + ".inst 0x4fbaea6e // sdot v14.4s, v19.16b, v26.4b[3]\n" + "ldr q26, [x10, #0x60]\n" + ".inst 0x4f82e371 // sdot v17.4s, v27.16b, v2.4b[0]\n" + ".inst 0x4f82e268 // sdot v8.4s, v19.16b, v2.4b[0]\n" + ".inst 0x4fa2e375 // sdot v21.4s, v27.16b, v2.4b[1]\n" + ".inst 0x4fa2e26a // sdot v10.4s, v19.16b, v2.4b[1]\n" + ".inst 0x4f82eb64 // sdot v4.4s, v27.16b, v2.4b[2]\n" + ".inst 0x4f82ea65 // sdot v5.4s, v19.16b, v2.4b[2]\n" + ".inst 0x4fa2eb7c // sdot v28.4s, v27.16b, v2.4b[3]\n" + "ldr q27, [x10, #0x70]\n" + ".inst 0x4fa2ea63 // sdot v3.4s, v19.16b, v2.4b[3]\n" + "ldr q2, [x22, #0x30]\n" + "ldr q19, [x20, #0x30]\n" + ".inst 
0x4f96e246 // sdot v6.4s, v18.16b, v22.4b[0]\n" + ".inst 0x4fb6e249 // sdot v9.4s, v18.16b, v22.4b[1]\n" + "add x10, x10, #0x80\n" + ".inst 0x4f96ea54 // sdot v20.4s, v18.16b, v22.4b[2]\n" + ".inst 0x4fb6ea4b // sdot v11.4s, v18.16b, v22.4b[3]\n" + ".inst 0x4f97e251 // sdot v17.4s, v18.16b, v23.4b[0]\n" + ".inst 0x4fb7e255 // sdot v21.4s, v18.16b, v23.4b[1]\n" + ".inst 0x4f97ea44 // sdot v4.4s, v18.16b, v23.4b[2]\n" + ".inst 0x4fb7ea5c // sdot v28.4s, v18.16b, v23.4b[3]\n" + "shl v18.16b, v16.16b, #0x4\n" + "and v16.16b, v16.16b, v13.16b\n" + ".inst 0x4f96e24f // sdot v15.4s, v18.16b, v22.4b[0]\n" + ".inst 0x4fb6e24c // sdot v12.4s, v18.16b, v22.4b[1]\n" + ".inst 0x4f96ea5e // sdot v30.4s, v18.16b, v22.4b[2]\n" + ".inst 0x4fb6ea4e // sdot v14.4s, v18.16b, v22.4b[3]\n" + "ldr q22, [x22, #0x40]\n" + ".inst 0x4f97e248 // sdot v8.4s, v18.16b, v23.4b[0]\n" + ".inst 0x4fb7e24a // sdot v10.4s, v18.16b, v23.4b[1]\n" + ".inst 0x4f97ea45 // sdot v5.4s, v18.16b, v23.4b[2]\n" + ".inst 0x4fb7ea43 // sdot v3.4s, v18.16b, v23.4b[3]\n" + "ldr q18, [x20, #0x40]\n" + "shl v23.16b, v29.16b, #0x4\n" + "and v29.16b, v29.16b, v13.16b\n" + ".inst 0x4f98e2e6 // sdot v6.4s, v23.16b, v24.4b[0]\n" + ".inst 0x4fb8e2e9 // sdot v9.4s, v23.16b, v24.4b[1]\n" + ".inst 0x4f98eaf4 // sdot v20.4s, v23.16b, v24.4b[2]\n" + ".inst 0x4fb8eaeb // sdot v11.4s, v23.16b, v24.4b[3]\n" + ".inst 0x4f80e2f1 // sdot v17.4s, v23.16b, v0.4b[0]\n" + ".inst 0x4fa0e2f5 // sdot v21.4s, v23.16b, v0.4b[1]\n" + ".inst 0x4f80eae4 // sdot v4.4s, v23.16b, v0.4b[2]\n" + ".inst 0x4fa0eafc // sdot v28.4s, v23.16b, v0.4b[3]\n" + "shl v23.16b, v25.16b, #0x4\n" + "and v25.16b, v25.16b, v13.16b\n" + ".inst 0x4f98e2ef // sdot v15.4s, v23.16b, v24.4b[0]\n" + ".inst 0x4fb8e2ec // sdot v12.4s, v23.16b, v24.4b[1]\n" + ".inst 0x4f98eafe // sdot v30.4s, v23.16b, v24.4b[2]\n" + ".inst 0x4fb8eaee // sdot v14.4s, v23.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x50]\n" + ".inst 0x4f80e2e8 // sdot v8.4s, v23.16b, v0.4b[0]\n" + ".inst 0x4fa0e2ea // 
sdot v10.4s, v23.16b, v0.4b[1]\n" + ".inst 0x4f80eae5 // sdot v5.4s, v23.16b, v0.4b[2]\n" + ".inst 0x4fa0eae3 // sdot v3.4s, v23.16b, v0.4b[3]\n" + "ldr q23, [x20, #0x50]\n" + "shl v0.16b, v26.16b, #0x4\n" + "and v26.16b, v26.16b, v13.16b\n" + ".inst 0x4f82e006 // sdot v6.4s, v0.16b, v2.4b[0]\n" + ".inst 0x4fa2e009 // sdot v9.4s, v0.16b, v2.4b[1]\n" + ".inst 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".inst 0x4fa2e80b // sdot v11.4s, v0.16b, v2.4b[3]\n" + ".inst 0x4f93e011 // sdot v17.4s, v0.16b, v19.4b[0]\n" + ".inst 0x4fb3e015 // sdot v21.4s, v0.16b, v19.4b[1]\n" + ".inst 0x4f93e804 // sdot v4.4s, v0.16b, v19.4b[2]\n" + ".inst 0x4fb3e81c // sdot v28.4s, v0.16b, v19.4b[3]\n" + "ldr q0, [x22, #0x60]\n" + ".inst 0x4f96e3e6 // sdot v6.4s, v31.16b, v22.4b[0]\n" + ".inst 0x4fb6e3e9 // sdot v9.4s, v31.16b, v22.4b[1]\n" + ".inst 0x4f96ebf4 // sdot v20.4s, v31.16b, v22.4b[2]\n" + ".inst 0x4fb6ebeb // sdot v11.4s, v31.16b, v22.4b[3]\n" + ".inst 0x4f92e3f1 // sdot v17.4s, v31.16b, v18.4b[0]\n" + ".inst 0x4fb2e3f5 // sdot v21.4s, v31.16b, v18.4b[1]\n" + ".inst 0x4f92ebe4 // sdot v4.4s, v31.16b, v18.4b[2]\n" + ".inst 0x4fb2ebfc // sdot v28.4s, v31.16b, v18.4b[3]\n" + "ldr q31, [x20, #0x60]\n" + ".inst 0x4f98e026 // sdot v6.4s, v1.16b, v24.4b[0]\n" + ".inst 0x4fb8e029 // sdot v9.4s, v1.16b, v24.4b[1]\n" + ".inst 0x4f98e834 // sdot v20.4s, v1.16b, v24.4b[2]\n" + ".inst 0x4fb8e82b // sdot v11.4s, v1.16b, v24.4b[3]\n" + ".inst 0x4f97e031 // sdot v17.4s, v1.16b, v23.4b[0]\n" + ".inst 0x4fb7e035 // sdot v21.4s, v1.16b, v23.4b[1]\n" + ".inst 0x4f97e824 // sdot v4.4s, v1.16b, v23.4b[2]\n" + ".inst 0x4fb7e83c // sdot v28.4s, v1.16b, v23.4b[3]\n" + "ldr q1, [x22, #0x70]\n" + "add x22, x22, #0x80\n" + ".inst 0x4f80e3a6 // sdot v6.4s, v29.16b, v0.4b[0]\n" + ".inst 0x4fa0e3a9 // sdot v9.4s, v29.16b, v0.4b[1]\n" + ".inst 0x4f80ebb4 // sdot v20.4s, v29.16b, v0.4b[2]\n" + ".inst 0x4fa0ebab // sdot v11.4s, v29.16b, v0.4b[3]\n" + ".inst 0x4f9fe3b1 // sdot v17.4s, v29.16b, v31.4b[0]\n" + 
".inst 0x4fbfe3b5 // sdot v21.4s, v29.16b, v31.4b[1]\n" + ".inst 0x4f9feba4 // sdot v4.4s, v29.16b, v31.4b[2]\n" + ".inst 0x4fbfebbc // sdot v28.4s, v29.16b, v31.4b[3]\n" + "ldr q29, [x20, #0x70]\n" + "add x20, x20, #0x80\n" + ".inst 0x4f81e346 // sdot v6.4s, v26.16b, v1.4b[0]\n" + ".inst 0x4fa1e349 // sdot v9.4s, v26.16b, v1.4b[1]\n" + ".inst 0x4f81eb54 // sdot v20.4s, v26.16b, v1.4b[2]\n" + ".inst 0x4fa1eb4b // sdot v11.4s, v26.16b, v1.4b[3]\n" + ".inst 0x4f9de351 // sdot v17.4s, v26.16b, v29.4b[0]\n" + ".inst 0x4fbde355 // sdot v21.4s, v26.16b, v29.4b[1]\n" + ".inst 0x4f9deb44 // sdot v4.4s, v26.16b, v29.4b[2]\n" + ".inst 0x4fbdeb5c // sdot v28.4s, v26.16b, v29.4b[3]\n" + "shl v26.16b, v27.16b, #0x4\n" + "and v27.16b, v27.16b, v13.16b\n" + ".inst 0x4f82e34f // sdot v15.4s, v26.16b, v2.4b[0]\n" + ".inst 0x4fa2e34c // sdot v12.4s, v26.16b, v2.4b[1]\n" + ".inst 0x4f82eb5e // sdot v30.4s, v26.16b, v2.4b[2]\n" + ".inst 0x4fa2eb4e // sdot v14.4s, v26.16b, v2.4b[3]\n" + ".inst 0x4f93e348 // sdot v8.4s, v26.16b, v19.4b[0]\n" + ".inst 0x4fb3e34a // sdot v10.4s, v26.16b, v19.4b[1]\n" + ".inst 0x4f93eb45 // sdot v5.4s, v26.16b, v19.4b[2]\n" + ".inst 0x4fb3eb43 // sdot v3.4s, v26.16b, v19.4b[3]\n" + ".inst 0x4f96e0ef // sdot v15.4s, v7.16b, v22.4b[0]\n" + ".inst 0x4fb6e0ec // sdot v12.4s, v7.16b, v22.4b[1]\n" + ".inst 0x4f96e8fe // sdot v30.4s, v7.16b, v22.4b[2]\n" + ".inst 0x4fb6e8ee // sdot v14.4s, v7.16b, v22.4b[3]\n" + ".inst 0x4f92e0e8 // sdot v8.4s, v7.16b, v18.4b[0]\n" + ".inst 0x4fb2e0ea // sdot v10.4s, v7.16b, v18.4b[1]\n" + ".inst 0x4f92e8e5 // sdot v5.4s, v7.16b, v18.4b[2]\n" + ".inst 0x4fb2e8e3 // sdot v3.4s, v7.16b, v18.4b[3]\n" + ".inst 0x4f98e20f // sdot v15.4s, v16.16b, v24.4b[0]\n" + ".inst 0x4fb8e20c // sdot v12.4s, v16.16b, v24.4b[1]\n" + ".inst 0x4f98ea1e // sdot v30.4s, v16.16b, v24.4b[2]\n" + ".inst 0x4fb8ea0e // sdot v14.4s, v16.16b, v24.4b[3]\n" + ".inst 0x4f97e208 // sdot v8.4s, v16.16b, v23.4b[0]\n" + ".inst 0x4fb7e20a // sdot v10.4s, v16.16b, 
v23.4b[1]\n" + ".inst 0x4f97ea05 // sdot v5.4s, v16.16b, v23.4b[2]\n" + ".inst 0x4fb7ea03 // sdot v3.4s, v16.16b, v23.4b[3]\n" + ".inst 0x4f80e32f // sdot v15.4s, v25.16b, v0.4b[0]\n" + ".inst 0x4fa0e32c // sdot v12.4s, v25.16b, v0.4b[1]\n" + ".inst 0x4f80eb3e // sdot v30.4s, v25.16b, v0.4b[2]\n" + ".inst 0x4fa0eb2e // sdot v14.4s, v25.16b, v0.4b[3]\n" + ".inst 0x4f9fe328 // sdot v8.4s, v25.16b, v31.4b[0]\n" + ".inst 0x4fbfe32a // sdot v10.4s, v25.16b, v31.4b[1]\n" + ".inst 0x4f9feb25 // sdot v5.4s, v25.16b, v31.4b[2]\n" + ".inst 0x4fbfeb23 // sdot v3.4s, v25.16b, v31.4b[3]\n" + ".inst 0x4f81e36f // sdot v15.4s, v27.16b, v1.4b[0]\n" + ".inst 0x4fa1e36c // sdot v12.4s, v27.16b, v1.4b[1]\n" + ".inst 0x4f81eb7e // sdot v30.4s, v27.16b, v1.4b[2]\n" + ".inst 0x4fa1eb6e // sdot v14.4s, v27.16b, v1.4b[3]\n" + ".inst 0x4f9de368 // sdot v8.4s, v27.16b, v29.4b[0]\n" + ".inst 0x4fbde36a // sdot v10.4s, v27.16b, v29.4b[1]\n" + ".inst 0x4f9deb65 // sdot v5.4s, v27.16b, v29.4b[2]\n" + ".inst 0x4fbdeb63 // sdot v3.4s, v27.16b, v29.4b[3]\n" + "bgt 3b\n" + "ldr q29, [x10, #0x0]\n" + "ldr q19, [x10, #0x10]\n" + "ld1 { v24.4s }, [x22]\n" + "ldr q1, [x10, #0x20]\n" + "add x22, x22, #0x10\n" + "ldr q2, [x10, #0x30]\n" + "ldr q31, [x22, #0x0]\n" + "add x10, x10, #0x40\n" + "mla v6.4s, v29.4s, v24.s[0]\n" + "mla v15.4s, v19.4s, v24.s[0]\n" + "mla v9.4s, v29.4s, v24.s[1]\n" + "mla v12.4s, v19.4s, v24.s[1]\n" + "mla v20.4s, v29.4s, v24.s[2]\n" + "mla v30.4s, v19.4s, v24.s[2]\n" + "mla v11.4s, v29.4s, v24.s[3]\n" + "fmul v7.4s, v1.4s, v31.s[0]\n" + "mla v14.4s, v19.4s, v24.s[3]\n" + "scvtf v6.4s, v6.4s\n" + "fmul v26.4s, v2.4s, v31.s[0]\n" + "scvtf v15.4s, v15.4s\n" + "fmul v24.4s, v1.4s, v31.s[1]\n" + "scvtf v9.4s, v9.4s\n" + "fmul v23.4s, v2.4s, v31.s[1]\n" + "scvtf v12.4s, v12.4s\n" + "fmul v25.4s, v1.4s, v31.s[2]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v27.4s, v2.4s, v31.s[2]\n" + "scvtf v30.4s, v30.4s\n" + "fmul v22.4s, v1.4s, v31.s[3]\n" + "scvtf v11.4s, v11.4s\n" + "fmul v31.4s, v2.4s, 
v31.s[3]\n" + "scvtf v14.4s, v14.4s\n" + "fmul v6.4s, v6.4s, v7.4s\n" + "fmul v15.4s, v15.4s, v26.4s\n" + "fmul v9.4s, v9.4s, v24.4s\n" + "fmul v12.4s, v12.4s, v23.4s\n" + "fmul v20.4s, v20.4s, v25.4s\n" + "fmul v30.4s, v30.4s, v27.4s\n" + "fmul v11.4s, v11.4s, v22.4s\n" + "fmul v14.4s, v14.4s, v31.4s\n" + "ld1 { v25.4s }, [x20]\n" + "add x20, x20, #0x10\n" + "ldr q0, [x20, #0x0]\n" + "mla v17.4s, v29.4s, v25.s[0]\n" + "mla v8.4s, v19.4s, v25.s[0]\n" + "mla v21.4s, v29.4s, v25.s[1]\n" + "mla v10.4s, v19.4s, v25.s[1]\n" + "mla v4.4s, v29.4s, v25.s[2]\n" + "mla v5.4s, v19.4s, v25.s[2]\n" + "mla v28.4s, v29.4s, v25.s[3]\n" + "fmul v26.4s, v1.4s, v0.s[0]\n" + "mla v3.4s, v19.4s, v25.s[3]\n" + "scvtf v17.4s, v17.4s\n" + "fmul v18.4s, v2.4s, v0.s[0]\n" + "scvtf v8.4s, v8.4s\n" + "fmul v24.4s, v1.4s, v0.s[1]\n" + "scvtf v21.4s, v21.4s\n" + "fmul v22.4s, v2.4s, v0.s[1]\n" + "scvtf v10.4s, v10.4s\n" + "fmul v27.4s, v1.4s, v0.s[2]\n" + "scvtf v4.4s, v4.4s\n" + "fmul v23.4s, v2.4s, v0.s[2]\n" + "scvtf v5.4s, v5.4s\n" + "fmul v25.4s, v1.4s, v0.s[3]\n" + "scvtf v28.4s, v28.4s\n" + "fmul v19.4s, v2.4s, v0.s[3]\n" + "scvtf v3.4s, v3.4s\n" + "fmul v17.4s, v17.4s, v26.4s\n" + "fmul v8.4s, v8.4s, v18.4s\n" + "fmul v21.4s, v21.4s, v24.4s\n" + "fmul v10.4s, v10.4s, v22.4s\n" + "fmul v4.4s, v4.4s, v27.4s\n" + "fmul v5.4s, v5.4s, v23.4s\n" + "fmul v28.4s, v28.4s, v25.4s\n" + "fmul v3.4s, v3.4s, v19.4s\n" + "ldr q2, [x10, #0x0]\n" + "ldr q22, [x10, #0x10]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x8\n" + "ld1r { v19.4s }, [%x[clamp_vals]]\n" + "ld1r { v7.4s }, [x20]\n" + "add x10, x10, #0x20\n" + "fadd v6.4s, v6.4s, v2.4s\n" + "fadd v15.4s, v15.4s, v22.4s\n" + "fadd v9.4s, v9.4s, v2.4s\n" + "fadd v12.4s, v12.4s, v22.4s\n" + "fadd v20.4s, v20.4s, v2.4s\n" + "fadd v30.4s, v30.4s, v22.4s\n" + "fadd v11.4s, v11.4s, v2.4s\n" + "fadd v14.4s, v14.4s, v22.4s\n" + "fadd v17.4s, v17.4s, v2.4s\n" + "fadd v8.4s, v8.4s, v22.4s\n" + "fadd v21.4s, v21.4s, v2.4s\n" + "fadd v10.4s, v10.4s, 
v22.4s\n" + "fadd v4.4s, v4.4s, v2.4s\n" + "fadd v5.4s, v5.4s, v22.4s\n" + "fadd v28.4s, v28.4s, v2.4s\n" + "fadd v3.4s, v3.4s, v22.4s\n" + "fmax v6.4s, v6.4s, v19.4s\n" + "fmax v15.4s, v15.4s, v19.4s\n" + "fmax v9.4s, v9.4s, v19.4s\n" + "fmax v12.4s, v12.4s, v19.4s\n" + "fmax v20.4s, v20.4s, v19.4s\n" + "fmax v30.4s, v30.4s, v19.4s\n" + "fmax v11.4s, v11.4s, v19.4s\n" + "fmax v14.4s, v14.4s, v19.4s\n" + "fmax v17.4s, v17.4s, v19.4s\n" + "fmax v8.4s, v8.4s, v19.4s\n" + "fmax v21.4s, v21.4s, v19.4s\n" + "fmax v10.4s, v10.4s, v19.4s\n" + "fmax v4.4s, v4.4s, v19.4s\n" + "fmax v5.4s, v5.4s, v19.4s\n" + "fmax v28.4s, v28.4s, v19.4s\n" + "fmax v3.4s, v3.4s, v19.4s\n" + "fmin v6.4s, v6.4s, v7.4s\n" + "fmin v15.4s, v15.4s, v7.4s\n" + "fmin v9.4s, v9.4s, v7.4s\n" + "fmin v12.4s, v12.4s, v7.4s\n" + "fmin v20.4s, v20.4s, v7.4s\n" + "fmin v30.4s, v30.4s, v7.4s\n" + "fmin v11.4s, v11.4s, v7.4s\n" + "fmin v14.4s, v14.4s, v7.4s\n" + "fmin v17.4s, v17.4s, v7.4s\n" + "fmin v8.4s, v8.4s, v7.4s\n" + "fmin v21.4s, v21.4s, v7.4s\n" + "fmin v10.4s, v10.4s, v7.4s\n" + "fmin v4.4s, v4.4s, v7.4s\n" + "fmin v5.4s, v5.4s, v7.4s\n" + "fmin v28.4s, v28.4s, v7.4s\n" + "fmin v3.4s, v3.4s, v7.4s\n" + "blt 6f\n" + "mov x20, %x[dst]\n" + "str q6, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q9, [x20, #0x0]\n" + "str q12, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q20, [x20, #0x0]\n" + "str q30, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q11, [x20, #0x0]\n" + "str q14, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q17, [x20, #0x0]\n" + "str q8, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "str q10, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q4, [x20, #0x0]\n" + "str q5, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q28, [x20, #0x0]\n" + "str q3, [x20, #0x10]\n" + "b 11f\n" + "6:" // Partial output + "mov x27, %x[dst]\n" + "add 
x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #2, 8f\n" + "st1 { v28.4s }, [x23], #0x10\n" + "st1 { v4.4s }, [x25], #0x10\n" + "st1 { v21.4s }, [x24], #0x10\n" + "st1 { v17.4s }, [x26], #0x10\n" + "st1 { v11.4s }, [x20], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v9.4s }, [x21], #0x10\n" + "st1 { v6.4s }, [x27], #0x10\n" + "tbz x9, #1, 7f\n" + "st1 { v3.d }[0], [x23], #0x8\n" + "st1 { v5.d }[0], [x25], #0x8\n" + "st1 { v10.d }[0], [x24], #0x8\n" + "st1 { v8.d }[0], [x26], #0x8\n" + "st1 { v14.d }[0], [x20], #0x8\n" + "st1 { v30.d }[0], [x22], #0x8\n" + "st1 { v12.d }[0], [x21], #0x8\n" + "st1 { v15.d }[0], [x27], #0x8\n" + "tbz x9, #0, 10f\n" + "st1 { v3.s }[2], [x23]\n" + "st1 { v5.s }[2], [x25]\n" + "st1 { v10.s }[2], [x24]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v14.s }[2], [x20]\n" + "st1 { v30.s }[2], [x22]\n" + "st1 { v12.s }[2], [x21]\n" + "st1 { v15.s }[2], [x27]\n" + "b 10f\n" + "7:" // Output block 0: partial_1_4 + "tbz x9, #0, 10f\n" + "st1 { v3.s }[0], [x23]\n" + "st1 { v5.s }[0], [x25]\n" + "st1 { v10.s }[0], [x24]\n" + "st1 { v8.s }[0], [x26]\n" + "st1 { v14.s }[0], [x20]\n" + "st1 { v30.s }[0], [x22]\n" + "st1 { v12.s }[0], [x21]\n" + "st1 { v15.s }[0], [x27]\n" + "b 10f\n" + "8:" // Output block 0: partial_2_0 + "tbz x9, #1, 9f\n" + "st1 { v28.d }[0], [x23], #0x8\n" + "st1 { v4.d }[0], [x25], #0x8\n" + "st1 { v21.d }[0], [x24], #0x8\n" + "st1 { v17.d }[0], [x26], #0x8\n" + "st1 { v11.d }[0], [x20], #0x8\n" + "st1 { v20.d }[0], [x22], #0x8\n" + "st1 { v9.d }[0], [x21], #0x8\n" + "st1 { v6.d }[0], [x27], #0x8\n" + "tbz x9, #0, 10f\n" + "st1 { v28.s }[2], [x23]\n" + "st1 { v4.s }[2], [x25]\n" + "st1 { v21.s }[2], [x24]\n" + "st1 { v17.s }[2], [x26]\n" + "st1 { v11.s }[2], [x20]\n" + "st1 { 
v20.s }[2], [x22]\n" + "st1 { v9.s }[2], [x21]\n" + "st1 { v6.s }[2], [x27]\n" + "b 10f\n" + "9:" // Output block 0: partial_1_0 + "st1 { v28.s }[0], [x23]\n" + "st1 { v4.s }[0], [x25]\n" + "st1 { v21.s }[0], [x24]\n" + "st1 { v17.s }[0], [x26]\n" + "st1 { v11.s }[0], [x20]\n" + "st1 { v20.s }[0], [x22]\n" + "st1 { v9.s }[0], [x21]\n" + "st1 { v6.s }[0], [x27]\n" + "10:" // Output block 0: Done + "11:" // Output stage exit + "subs x9, x9, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "12:" // Row loop skip + "cbz x12, 23f\n" + "13:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "14:" // Row tail: Column loop + "mov x22, %x[lhs_packed]\n" + "movi v6.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "mov x20, %x[num_blocks]\n" + "movi v9.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "15:" // Row tail: Sub block loop + "ldr q10, [x26, #0x0]\n" + "ldr q8, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q7, [x22, #0x0]\n" + "ldr q5, [x26, #0x20]\n" + "ldr q4, [x26, #0x30]\n" + "ldr q3, [x22, #0x10]\n" + "ldr q17, [x26, #0x40]\n" + "ldr q1, [x26, #0x50]\n" + "shl v29.16b, v10.16b, #0x4\n" + "shl v18.16b, v8.16b, #0x4\n" + "ldr q2, [x22, #0x20]\n" + "ldr q31, [x26, #0x60]\n" + "shl v27.16b, v5.16b, #0x4\n" + "and v10.16b, v10.16b, v13.16b\n" + "ldr q0, [x26, #0x70]\n" + "ldr q28, [x22, #0x30]\n" + "shl v26.16b, v4.16b, #0x4\n" + "and v8.16b, v8.16b, v13.16b\n" + "ldr q25, [x22, #0x40]\n" + "ldr q24, [x22, #0x50]\n" + ".inst 0x4f87e3a6 // sdot v6.4s, v29.16b, v7.4b[0]\n" + ".inst 0x4f87e24f // sdot v15.4s, v18.16b, v7.4b[0]\n" + "ldr q23, [x22, #0x60]\n" + "ldr q22, [x22, #0x70]\n" + ".inst 0x4fa7e3a9 // sdot v9.4s, v29.16b, v7.4b[1]\n" + ".inst 0x4fa7e24c 
// sdot v12.4s, v18.16b, v7.4b[1]\n" + ".inst 0x4f87ebb4 // sdot v20.4s, v29.16b, v7.4b[2]\n" + ".inst 0x4f87ea5e // sdot v30.4s, v18.16b, v7.4b[2]\n" + "shl v21.16b, v17.16b, #0x4\n" + "add x26, x26, #0x80\n" + ".inst 0x4fa7ebab // sdot v11.4s, v29.16b, v7.4b[3]\n" + ".inst 0x4fa7ea4e // sdot v14.4s, v18.16b, v7.4b[3]\n" + "shl v29.16b, v1.16b, #0x4\n" + "add x22, x22, #0x80\n" + ".inst 0x4f83e366 // sdot v6.4s, v27.16b, v3.4b[0]\n" + ".inst 0x4f83e34f // sdot v15.4s, v26.16b, v3.4b[0]\n" + "shl v19.16b, v31.16b, #0x4\n" + ".inst 0x4fa3e369 // sdot v9.4s, v27.16b, v3.4b[1]\n" + ".inst 0x4fa3e34c // sdot v12.4s, v26.16b, v3.4b[1]\n" + "shl v18.16b, v0.16b, #0x4\n" + ".inst 0x4f83eb74 // sdot v20.4s, v27.16b, v3.4b[2]\n" + ".inst 0x4f83eb5e // sdot v30.4s, v26.16b, v3.4b[2]\n" + "and v5.16b, v5.16b, v13.16b\n" + ".inst 0x4fa3eb6b // sdot v11.4s, v27.16b, v3.4b[3]\n" + ".inst 0x4fa3eb4e // sdot v14.4s, v26.16b, v3.4b[3]\n" + "and v4.16b, v4.16b, v13.16b\n" + ".inst 0x4f82e2a6 // sdot v6.4s, v21.16b, v2.4b[0]\n" + ".inst 0x4f82e3af // sdot v15.4s, v29.16b, v2.4b[0]\n" + "and v17.16b, v17.16b, v13.16b\n" + ".inst 0x4fa2e2a9 // sdot v9.4s, v21.16b, v2.4b[1]\n" + ".inst 0x4fa2e3ac // sdot v12.4s, v29.16b, v2.4b[1]\n" + "and v1.16b, v1.16b, v13.16b\n" + ".inst 0x4f82eab4 // sdot v20.4s, v21.16b, v2.4b[2]\n" + ".inst 0x4f82ebbe // sdot v30.4s, v29.16b, v2.4b[2]\n" + "and v31.16b, v31.16b, v13.16b\n" + ".inst 0x4fa2eaab // sdot v11.4s, v21.16b, v2.4b[3]\n" + ".inst 0x4fa2ebae // sdot v14.4s, v29.16b, v2.4b[3]\n" + "and v0.16b, v0.16b, v13.16b\n" + ".inst 0x4f9ce266 // sdot v6.4s, v19.16b, v28.4b[0]\n" + ".inst 0x4f9ce24f // sdot v15.4s, v18.16b, v28.4b[0]\n" + ".inst 0x4fbce269 // sdot v9.4s, v19.16b, v28.4b[1]\n" + ".inst 0x4fbce24c // sdot v12.4s, v18.16b, v28.4b[1]\n" + ".inst 0x4f9cea74 // sdot v20.4s, v19.16b, v28.4b[2]\n" + ".inst 0x4f9cea5e // sdot v30.4s, v18.16b, v28.4b[2]\n" + ".inst 0x4fbcea6b // sdot v11.4s, v19.16b, v28.4b[3]\n" + ".inst 0x4fbcea4e // sdot 
v14.4s, v18.16b, v28.4b[3]\n" + ".inst 0x4f99e146 // sdot v6.4s, v10.16b, v25.4b[0]\n" + ".inst 0x4f99e10f // sdot v15.4s, v8.16b, v25.4b[0]\n" + ".inst 0x4fb9e149 // sdot v9.4s, v10.16b, v25.4b[1]\n" + ".inst 0x4fb9e10c // sdot v12.4s, v8.16b, v25.4b[1]\n" + ".inst 0x4f99e954 // sdot v20.4s, v10.16b, v25.4b[2]\n" + ".inst 0x4f99e91e // sdot v30.4s, v8.16b, v25.4b[2]\n" + ".inst 0x4fb9e94b // sdot v11.4s, v10.16b, v25.4b[3]\n" + ".inst 0x4fb9e90e // sdot v14.4s, v8.16b, v25.4b[3]\n" + ".inst 0x4f98e0a6 // sdot v6.4s, v5.16b, v24.4b[0]\n" + ".inst 0x4f98e08f // sdot v15.4s, v4.16b, v24.4b[0]\n" + ".inst 0x4fb8e0a9 // sdot v9.4s, v5.16b, v24.4b[1]\n" + ".inst 0x4fb8e08c // sdot v12.4s, v4.16b, v24.4b[1]\n" + ".inst 0x4f98e8b4 // sdot v20.4s, v5.16b, v24.4b[2]\n" + ".inst 0x4f98e89e // sdot v30.4s, v4.16b, v24.4b[2]\n" + ".inst 0x4fb8e8ab // sdot v11.4s, v5.16b, v24.4b[3]\n" + ".inst 0x4fb8e88e // sdot v14.4s, v4.16b, v24.4b[3]\n" + ".inst 0x4f97e226 // sdot v6.4s, v17.16b, v23.4b[0]\n" + ".inst 0x4f97e02f // sdot v15.4s, v1.16b, v23.4b[0]\n" + ".inst 0x4fb7e229 // sdot v9.4s, v17.16b, v23.4b[1]\n" + ".inst 0x4fb7e02c // sdot v12.4s, v1.16b, v23.4b[1]\n" + ".inst 0x4f97ea34 // sdot v20.4s, v17.16b, v23.4b[2]\n" + ".inst 0x4f97e83e // sdot v30.4s, v1.16b, v23.4b[2]\n" + ".inst 0x4fb7ea2b // sdot v11.4s, v17.16b, v23.4b[3]\n" + ".inst 0x4fb7e82e // sdot v14.4s, v1.16b, v23.4b[3]\n" + ".inst 0x4f96e3e6 // sdot v6.4s, v31.16b, v22.4b[0]\n" + ".inst 0x4f96e00f // sdot v15.4s, v0.16b, v22.4b[0]\n" + ".inst 0x4fb6e3e9 // sdot v9.4s, v31.16b, v22.4b[1]\n" + ".inst 0x4fb6e00c // sdot v12.4s, v0.16b, v22.4b[1]\n" + ".inst 0x4f96ebf4 // sdot v20.4s, v31.16b, v22.4b[2]\n" + ".inst 0x4f96e81e // sdot v30.4s, v0.16b, v22.4b[2]\n" + ".inst 0x4fb6ebeb // sdot v11.4s, v31.16b, v22.4b[3]\n" + ".inst 0x4fb6e80e // sdot v14.4s, v0.16b, v22.4b[3]\n" + "bgt 15b\n" + "ldr q21, [x26, #0x0]\n" + "ldr q4, [x26, #0x10]\n" + "ld1 { v19.4s }, [x22]\n" + "ldr q25, [x26, #0x20]\n" + "add x22, x22, 
#0x10\n" + "ldr q24, [x26, #0x30]\n" + "ldr q18, [x22, #0x0]\n" + "add x26, x26, #0x40\n" + "mla v6.4s, v21.4s, v19.s[0]\n" + "mla v15.4s, v4.4s, v19.s[0]\n" + "mla v9.4s, v21.4s, v19.s[1]\n" + "mla v12.4s, v4.4s, v19.s[1]\n" + "mla v20.4s, v21.4s, v19.s[2]\n" + "mla v30.4s, v4.4s, v19.s[2]\n" + "mla v11.4s, v21.4s, v19.s[3]\n" + "fmul v28.4s, v25.4s, v18.s[0]\n" + "mla v14.4s, v4.4s, v19.s[3]\n" + "scvtf v6.4s, v6.4s\n" + "fmul v22.4s, v24.4s, v18.s[0]\n" + "scvtf v15.4s, v15.4s\n" + "fmul v21.4s, v25.4s, v18.s[1]\n" + "scvtf v9.4s, v9.4s\n" + "fmul v1.4s, v24.4s, v18.s[1]\n" + "scvtf v12.4s, v12.4s\n" + "fmul v19.4s, v25.4s, v18.s[2]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v24.4s, v18.s[2]\n" + "scvtf v30.4s, v30.4s\n" + "fmul v23.4s, v25.4s, v18.s[3]\n" + "scvtf v11.4s, v11.4s\n" + "fmul v2.4s, v24.4s, v18.s[3]\n" + "scvtf v14.4s, v14.4s\n" + "fmul v6.4s, v6.4s, v28.4s\n" + "fmul v15.4s, v15.4s, v22.4s\n" + "fmul v9.4s, v9.4s, v21.4s\n" + "fmul v12.4s, v12.4s, v1.4s\n" + "fmul v20.4s, v20.4s, v19.4s\n" + "fmul v30.4s, v30.4s, v10.4s\n" + "fmul v11.4s, v11.4s, v23.4s\n" + "fmul v14.4s, v14.4s, v2.4s\n" + "ldr q19, [x26, #0x0]\n" + "ldr q18, [x26, #0x10]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x8\n" + "ld1r { v25.4s }, [%x[clamp_vals]]\n" + "ld1r { v26.4s }, [x20]\n" + "add x26, x26, #0x20\n" + "fadd v6.4s, v6.4s, v19.4s\n" + "fadd v15.4s, v15.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v19.4s\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fadd v20.4s, v20.4s, v19.4s\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fadd v11.4s, v11.4s, v19.4s\n" + "fadd v14.4s, v14.4s, v18.4s\n" + "fmax v6.4s, v6.4s, v25.4s\n" + "fmax v15.4s, v15.4s, v25.4s\n" + "fmax v9.4s, v9.4s, v25.4s\n" + "fmax v12.4s, v12.4s, v25.4s\n" + "fmax v20.4s, v20.4s, v25.4s\n" + "fmax v30.4s, v30.4s, v25.4s\n" + "fmax v11.4s, v11.4s, v25.4s\n" + "fmax v14.4s, v14.4s, v25.4s\n" + "fmin v6.4s, v6.4s, v26.4s\n" + "fmin v15.4s, v15.4s, v26.4s\n" + "fmin v9.4s, v9.4s, v26.4s\n" + "fmin v12.4s, v12.4s, 
v26.4s\n" + "fmin v20.4s, v20.4s, v26.4s\n" + "fmin v30.4s, v30.4s, v26.4s\n" + "fmin v11.4s, v11.4s, v26.4s\n" + "fmin v14.4s, v14.4s, v26.4s\n" + "blt 17f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q6, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x12, #0x2\n" + "str q9, [x20, #0x0]\n" + "str q12, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x12, #0x3\n" + "str q20, [x20, #0x0]\n" + "str q30, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "str q11, [x20, #0x0]\n" + "str q14, [x20, #0x10]\n" + "b 22f\n" + "17:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #2, 19f\n" + "st1 { v11.4s }, [x20], #0x10\n" + "st1 { v20.4s }, [x21], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v6.4s }, [x23], #0x10\n" + "tbz x25, #1, 18f\n" + "st1 { v14.d }[0], [x20], #0x8\n" + "st1 { v30.d }[0], [x21], #0x8\n" + "st1 { v12.d }[0], [x22], #0x8\n" + "st1 { v15.d }[0], [x23], #0x8\n" + "tbz x25, #0, 21f\n" + "st1 { v14.s }[2], [x20]\n" + "st1 { v30.s }[2], [x21]\n" + "st1 { v12.s }[2], [x22]\n" + "st1 { v15.s }[2], [x23]\n" + "b 21f\n" + "18:" // Row tail: Output block 0: partial_1_4 + "tbz x25, #0, 21f\n" + "st1 { v14.s }[0], [x20]\n" + "st1 { v30.s }[0], [x21]\n" + "st1 { v12.s }[0], [x22]\n" + "st1 { v15.s }[0], [x23]\n" + "b 21f\n" + "19:" // Row tail: Output block 0: partial_2_0 + "tbz x25, #1, 20f\n" + "st1 { v11.d }[0], [x20], #0x8\n" + "st1 { v20.d }[0], [x21], #0x8\n" + "st1 { v9.d }[0], [x22], #0x8\n" + "st1 { v6.d }[0], [x23], #0x8\n" + "tbz x25, #0, 21f\n" + "st1 { v11.s }[2], [x20]\n" + "st1 { v20.s }[2], [x21]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v6.s }[2], 
[x23]\n" + "b 21f\n" + "20:" // Row tail: Output block 0: partial_1_0 + "st1 { v11.s }[0], [x20]\n" + "st1 { v20.s }[0], [x21]\n" + "st1 { v9.s }[0], [x22]\n" + "st1 { v6.s }[0], [x23]\n" + "21:" // Row tail: Output block 0: Done + "22:" // Row tail: Output stage exit + "subs x25, x25, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 14b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 13b\n" + "23:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} + +#endif // Architectural features check. 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..6e77b84ab9d3e89751b746cf52fe03ab4cb5e15e --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h @@ -0,0 +1,136 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix with +/// the @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the RHS matrix with +/// the @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the RHS matrix with +/// the @ref kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dxp) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4cxp) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). 
+/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod( + size_t n_idx, // + size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of 8. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod(size_t m, size_t n); + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dxp) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cxp) and packed. +/// Output tile: (rows x cols) = 8 x 8 +/// Accumulation performed in a single for loop: 32 +/// Extension used: dotprod +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. +/// When the activations are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. 
+/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref +/// kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 +/// OR kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod( + size_t m, size_t n, size_t k, + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index fabb6611e9bbf3c890840e18a817e531327dcd90..53aaf00b87a24e39e6a1bc46b89ac2e33e941d03 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -15,8 +15,10 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" @@ -34,10 +36,12 @@ namespace kai::test { -static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel>, 6> +static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel>, 8> variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp = {{ UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, cpu_has_dotprod), UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, cpu_has_dotprod), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod, cpu_has_dotprod), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x4_8x8x32_neon_dotprod, cpu_has_dotprod), UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, cpu_has_i8mm), UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, cpu_has_i8mm), UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, cpu_has_i8mm),